From caa598c0e4fe1a3b9ccd055d7cd2210930eaef2b Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 21:10:51 -0400 Subject: [PATCH 01/83] [fix-tools]Bugfixes and refactoring --- tools/C-logging.rkt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/C-logging.rkt b/tools/C-logging.rkt index 2150a6c..33780ec 100644 --- a/tools/C-logging.rkt +++ b/tools/C-logging.rkt @@ -50,7 +50,9 @@ ((eq? data 'reset) (loop 0.0 0.0 0.0 0.0 0.0 0 sampling-frequency)) ((and data (= sampling-count 1)) - (print-average (/ (sum-all d0 d1 d2 d3 d4 data) (* 6 (ρ (product (shape data))))) count) + (print-average (/ (ρ (sum-all d0 d1 d2 d3 d4 data)) + (* 6 (ρ (product (shape data))))) + count) (loop d1 d2 d3 d4 data (add1 count) sampling-frequency)) (data (loop d1 d2 d3 d4 data (add1 count) (sub1 sampling-count))) From 1b1f4e40f08d45a6688a0d64379f6ed97cfda814 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sun, 28 Apr 2024 10:22:12 -0400 Subject: [PATCH 02/83] [fix-make]Fix config loading --- Makefile | 4 +++- impl-loader.rkt | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 2b8078d..7ceafaa 100644 --- a/Makefile +++ b/Makefile @@ -244,10 +244,12 @@ build: $(SOURCES) # Test it all test: @ echo "Running tests ..." &&\ + export MALT_PREFERENCES="$(shell pwd)/local.cfg";\ $(RACO) test $(TEST_FLAGS) $(SOURCES) one: - -@ $(RACO) make $(ARG) && $(RACO) test $(ARG) + -@ export MALT_PREFERENCES="$(shell pwd)/local.cfg";\ + $(RACO) make $(ARG) && $(RACO) test $(ARG) clean: find . -name 'compiled' | xargs -I% rm -rf % diff --git a/impl-loader.rkt b/impl-loader.rkt index 4fc5f1b..fdbea08 100644 --- a/impl-loader.rkt +++ b/impl-loader.rkt @@ -22,7 +22,7 @@ (define init-settings (λ () - (settings (make-hash (read-preferences "local.cfg"))))) + (settings (make-hash (read-preferences (or (getenv "MALT_PREFERENCES") "local.cfg")))))) ;;-------------------------------- ;; Config params so far @@ -37,6 +37,8 @@ `((tensor-implementation learner))) (when (not (settings)) - (init-settings)) + (init-settings) + (println "settings=") + (pretty-print (settings))) (provide tensor-implementation) From c2e9d386540922d0d0c458723707bc1887da576f Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:49:55 -0400 Subject: [PATCH 03/83] =?UTF-8?q?[fix-learner]use=20one=20kernel=20in=20ex?= =?UTF-8?q?t2-=E2=88=87=20and=20accelerate=20all=20ext*=20function=20calls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- learner.rkt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/learner.rkt b/learner.rkt index 0068aa4..fd191f7 100644 --- a/learner.rkt +++ b/learner.rkt @@ -13,6 +13,8 @@ (error "ext2-∇ is not provided by the learner implementation"))) (provide + tolerance + len ref refr tref tlen tmap list->tensor tensor build-tensor From 61b81d70b2f4ba884b804d527ef8c7a4c2d1fe40 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:26:58 -0400 Subject: [PATCH 04/83] [fix-learner]Add zeroes as a primitive --- learner.rkt | 2 +- learner/ext-ops.rkt | 2 +- learner/ext-ops/J-nd-ops.rkt | 5 ++++- learner/no-duals-no-overrides.rkt | 2 +- learner/no-duals.rkt | 2 +- learner/no-overrides.rkt | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/learner.rkt b/learner.rkt index fd191f7..e49333d 100644 --- a/learner.rkt +++ b/learner.rkt @@ -41,7 +41,7 @@ rectify flatten concat concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/learner/ext-ops.rkt b/learner/ext-ops.rkt index 5db1276..d1c8413 100644 --- a/learner/ext-ops.rkt +++ b/learner/ext-ops.rkt @@ -21,7 +21,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - sqrt-ρ sqr-ρ) + sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/learner/ext-ops/J-nd-ops.rkt b/learner/ext-ops/J-nd-ops.rkt index 107e896..fff3489 100644 --- a/learner/ext-ops/J-nd-ops.rkt +++ b/learner/ext-ops/J-nd-ops.rkt @@ -35,9 +35,12 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1 (λ (_) 0.0) 0)) + (provide +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (define *-2-1-ρ (ext2 *-ρ 2 1)) diff --git a/learner/no-duals-no-overrides.rkt b/learner/no-duals-no-overrides.rkt index 5b40500..ee52a1b 100644 --- a/learner/no-duals-no-overrides.rkt +++ b/learner/no-duals-no-overrides.rkt @@ -18,7 +18,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ diff --git a/learner/no-duals.rkt b/learner/no-duals.rkt index b20d585..361dd22 100644 --- a/learner/no-duals.rkt +++ b/learner/no-duals.rkt @@ -21,7 +21,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/learner/no-overrides.rkt b/learner/no-overrides.rkt index d2d521f..13640e5 100644 --- a/learner/no-overrides.rkt +++ b/learner/no-overrides.rkt @@ -33,7 +33,7 @@ (rename-out (concat-n d-concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ From 2a3cccdc9e01e773d13e504e202246b4675ab6c2 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:52:17 -0400 Subject: [PATCH 05/83] =?UTF-8?q?[fix-nested]use=20one=20kernel=20in=20ext?= =?UTF-8?q?2-=E2=88=87=20and=20accelerate=20all=20ext*=20function=20calls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- nested-tensors.rkt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nested-tensors.rkt b/nested-tensors.rkt index 277061c..571c7cc 100644 --- a/nested-tensors.rkt +++ b/nested-tensors.rkt @@ -8,6 +8,8 @@ (require "nested-tensors/ext-ops.rkt") (provide + tolerance + len ref refr tref tlen tmap list->tensor tensor build-tensor From 70b38d06dba1061fdf03c1c59d0ef7d3276f79e7 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:27:41 -0400 Subject: [PATCH 06/83] [fix-nested]Add zeroes as a primitive --- nested-tensors.rkt | 2 +- nested-tensors/ext-ops.rkt | 2 +- nested-tensors/ext-ops/A-scalar-ops.rkt | 5 ++++- nested-tensors/no-duals-no-overrides.rkt | 2 +- nested-tensors/no-duals.rkt | 2 +- nested-tensors/no-overrides.rkt | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/nested-tensors.rkt b/nested-tensors.rkt index 571c7cc..b1db486 100644 --- a/nested-tensors.rkt +++ b/nested-tensors.rkt @@ -32,7 +32,7 @@ (d-flatten flatten) (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ diff --git a/nested-tensors/ext-ops.rkt b/nested-tensors/ext-ops.rkt index 41428e1..473f37f 100644 --- a/nested-tensors/ext-ops.rkt +++ b/nested-tensors/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/nested-tensors/ext-ops/A-scalar-ops.rkt b/nested-tensors/ext-ops/A-scalar-ops.rkt index 3e98bd1..bcd1596 100644 --- a/nested-tensors/ext-ops/A-scalar-ops.rkt +++ b/nested-tensors/ext-ops/A-scalar-ops.rkt @@ -119,6 +119,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + (include "test/test-A-scalar-ops.rkt") (provide d+ d- d* d/ @@ -130,4 +133,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/nested-tensors/no-duals-no-overrides.rkt b/nested-tensors/no-duals-no-overrides.rkt index c3bbad0..ac909a5 100644 --- a/nested-tensors/no-duals-no-overrides.rkt +++ b/nested-tensors/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ diff --git a/nested-tensors/no-duals.rkt b/nested-tensors/no-duals.rkt index baa635a..d9dcf7d 100644 --- a/nested-tensors/no-duals.rkt +++ b/nested-tensors/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) diff --git a/nested-tensors/no-overrides.rkt b/nested-tensors/no-overrides.rkt index 3d39c8b..417a96e 100644 --- a/nested-tensors/no-overrides.rkt +++ b/nested-tensors/no-overrides.rkt @@ -35,7 +35,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ From 5a5946d9c6cdee064034a7950485244b134c96e3 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 17:49:33 -0400 Subject: [PATCH 07/83] [fix-flat]Generate indices like a GPU kernel --- flat-tensors/tensors/test/test-D-extend.rkt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flat-tensors/tensors/test/test-D-extend.rkt b/flat-tensors/tensors/test/test-D-extend.rkt index b6026f1..34d9618 100644 --- a/flat-tensors/tensors/test/test-D-extend.rkt +++ b/flat-tensors/tensors/test/test-D-extend.rkt @@ -43,8 +43,8 @@ (define dup (ext1-ρ dup-f 1 dup-shape-f)) (check-equal? (flat-store (dup t0)) - #(0 2 4 6 0 2 4 6 - 8 10 12 14 8 10 12 14 + #(0 2 4 6 0 2 4 6 + 8 10 12 14 8 10 12 14 16 18 20 22 16 18 20 22 24 26 28 30 24 26 28 30 32 34 36 38 32 34 36 38 From 3ef3fb16a9461f7fc367e6b84a9e08859b031fef Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 18:43:27 -0400 Subject: [PATCH 08/83] =?UTF-8?q?[fix-flat]Use=20z=20offset=20in=20flat-ex?= =?UTF-8?q?t1-=E2=88=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flat-tensors/tensors/D-extend.rkt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/flat-tensors/tensors/D-extend.rkt b/flat-tensors/tensors/D-extend.rkt index bc698fa..d2bde38 100644 --- a/flat-tensors/tensors/D-extend.rkt +++ b/flat-tensors/tensors/D-extend.rkt @@ -209,11 +209,12 @@ (sf-z (shape-fn sf0)) (stride-z (size-of sf-z)) (vz (flat-store z)) + (offz (flat-offset z)) (g0 (new-vec size0 0.0))) (for ([iz (in-range 0 size-z stride-z)] [i0 (in-range off0 (+ off0 size0) stride0)]) - (fᵈ g0 v0 i0 stride0 vz iz stride-z)) + (fᵈ g0 v0 i0 stride0 vz (+ offz iz) stride-z)) (flat s0 g0 0)))) (define flat-ext2-ρ From 5e47463addbb5578cc40aa7b73892dcd278284ee Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:27:15 -0400 Subject: [PATCH 09/83] [fix-flat]Minor fixes and a bugged accelerated runtime --- flat-tensors/tensors/C-tensor-ops.rkt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flat-tensors/tensors/C-tensor-ops.rkt b/flat-tensors/tensors/C-tensor-ops.rkt index f9056b4..cf32d23 100644 --- a/flat-tensors/tensors/C-tensor-ops.rkt +++ b/flat-tensors/tensors/C-tensor-ops.rkt @@ -26,7 +26,7 @@ (cond ((= (size-of s) (flat-size t)) (flat s (flat-store t) (flat-offset t))) - (else (error "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + (else (error 'tensor-reshape "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) (include "test/test-C-tensor-ops.rkt") From 478a516491c93a5b3796af266356cfdf64690214 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:55:29 -0400 Subject: [PATCH 10/83] =?UTF-8?q?[fix-flat]use=20one=20kernel=20in=20ext2-?= =?UTF-8?q?=E2=88=87=20and=20accelerate=20all=20ext*=20function=20calls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flat-tensors.rkt | 2 ++ flat-tensors/ext-ops/E-argmax.rkt | 2 +- flat-tensors/ext-ops/F-max.rkt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/flat-tensors.rkt b/flat-tensors.rkt index 3685ce3..5464e8b 100644 --- a/flat-tensors.rkt +++ b/flat-tensors.rkt @@ -8,6 +8,8 @@ (require "flat-tensors/ext-ops.rkt") (provide + tolerance + len ref refr tref tlen list->tensor tensor build-tensor diff --git a/flat-tensors/ext-ops/E-argmax.rkt b/flat-tensors/ext-ops/E-argmax.rkt index ad4aefe..3130843 100644 --- a/flat-tensors/ext-ops/E-argmax.rkt +++ b/flat-tensors/ext-ops/E-argmax.rkt @@ -7,7 +7,7 @@ (λ (v0 i0 stride0 v-out i-out stride-out) (vector-set! v-out i-out - (for/fold ([max 0.0] + (for/fold ([max -inf.0] [max-i -1] #:result max-i) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vector-ref v0 i))) diff --git a/flat-tensors/ext-ops/F-max.rkt b/flat-tensors/ext-ops/F-max.rkt index 5a54bd5..77e3d26 100644 --- a/flat-tensors/ext-ops/F-max.rkt +++ b/flat-tensors/ext-ops/F-max.rkt @@ -7,7 +7,7 @@ (λ (v0 i0 stride0 v-out i-out stride-out) (vector-set! v-out i-out - (for/fold ([max 0.0]) + (for/fold ([max -inf.0]) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vector-ref v0 i))) (cond From 1213785ba5a210e6975bc7721aaeaa0729784c97 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:28:32 -0400 Subject: [PATCH 11/83] [fix-flat]Add zeroes as a primitive --- flat-tensors.rkt | 2 +- flat-tensors/ext-ops.rkt | 2 +- flat-tensors/ext-ops/A-scalar-ops.rkt | 5 ++++- flat-tensors/no-duals-no-overrides.rkt | 2 +- flat-tensors/no-duals.rkt | 2 +- flat-tensors/no-overrides.rkt | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/flat-tensors.rkt b/flat-tensors.rkt index 5464e8b..e4fae31 100644 --- a/flat-tensors.rkt +++ b/flat-tensors.rkt @@ -33,7 +33,7 @@ (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/flat-tensors/ext-ops.rkt b/flat-tensors/ext-ops.rkt index 83af7de..fc223f5 100644 --- a/flat-tensors/ext-ops.rkt +++ b/flat-tensors/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/flat-tensors/ext-ops/A-scalar-ops.rkt b/flat-tensors/ext-ops/A-scalar-ops.rkt index b251e80..9096209 100644 --- a/flat-tensors/ext-ops/A-scalar-ops.rkt +++ b/flat-tensors/ext-ops/A-scalar-ops.rkt @@ -121,6 +121,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 @@ -132,4 +135,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/flat-tensors/no-duals-no-overrides.rkt b/flat-tensors/no-duals-no-overrides.rkt index ac07a7a..07ca22e 100644 --- a/flat-tensors/no-duals-no-overrides.rkt +++ b/flat-tensors/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ concat-ρ concat-n-ρ flatten-ρ diff --git a/flat-tensors/no-duals.rkt b/flat-tensors/no-duals.rkt index 927c8c7..cd1bcaf 100644 --- a/flat-tensors/no-duals.rkt +++ b/flat-tensors/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/flat-tensors/no-overrides.rkt b/flat-tensors/no-overrides.rkt index 35dcbdd..05844b7 100644 --- a/flat-tensors/no-overrides.rkt +++ b/flat-tensors/no-overrides.rkt @@ -34,7 +34,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ From 845ce877423e62cb141aa4c7f1979caec520991d Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:58:11 -0400 Subject: [PATCH 12/83] [fix-flat]Fix apply-*-2 --- flat-tensors/autodiff/B-prims.rkt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flat-tensors/autodiff/B-prims.rkt b/flat-tensors/autodiff/B-prims.rkt index ccbe516..1bb8d9e 100644 --- a/flat-tensors/autodiff/B-prims.rkt +++ b/flat-tensors/autodiff/B-prims.rkt @@ -151,7 +151,7 @@ (λ (ρ-fn ra rb shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) @@ -175,7 +175,7 @@ (λ (∇-fn ra rb z shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) From 216965dc6fb9889f5e830de0aa3b4712bcd839cd Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 17 Jul 2024 00:15:32 -0400 Subject: [PATCH 13/83] Remove zeroes defined in malted --- malted/E-gd-common.rkt | 3 --- 1 file changed, 3 deletions(-) diff --git a/malted/E-gd-common.rkt b/malted/E-gd-common.rkt index 35a9549..84a75f3 100644 --- a/malted/E-gd-common.rkt +++ b/malted/E-gd-common.rkt @@ -3,9 +3,6 @@ ;; Extended operators are non-dualized (require "../base-no-duals.rkt") -(define zeroes - (ext1-ρ (λ (_) 0.0) 0)) - (define smooth (λ (decay-rate average g) (+ (* decay-rate average) From c5e8efb88dc6eb60c390feac7455b9b097702784 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 20:04:00 -0400 Subject: [PATCH 14/83] [fix-malted]Use check-dual in malted/test/test-E-gd-common --- malted/test/test-E-gd-common.rkt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/malted/test/test-E-gd-common.rkt b/malted/test/test-E-gd-common.rkt index 092a062..f105a4b 100644 --- a/malted/test/test-E-gd-common.rkt +++ b/malted/test/test-E-gd-common.rkt @@ -1,12 +1,13 @@ (module+ test (require rackunit) + (require "../impl.rkt") - (check-equal? (zeroes (tensor 1 2 3)) + (check-dual-equal? (zeroes (tensor 1 2 3)) (tensor 0.0 0.0 0.0)) - (check-equal? (smooth 0.9 31 -8) 27.1) + (check-dual-equal? (smooth 0.9 31 -8) 27.1) - (check-equal? (smooth 0.9 27.1 4) 24.79) + (check-dual-equal? (smooth 0.9 27.1 4) 24.79) (with-hypers ((mu 0.5) (beta 0.3)) - (check-equal? (+ mu beta) 0.8))) + (check-dual-equal? (+ mu beta) 0.8))) From f5323dfac5bc07015d7181131b80df2dd841ce27 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 21:42:19 -0400 Subject: [PATCH 15/83] [fix-malted]Tidy up TODOs --- malted/I-adam.rkt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/malted/I-adam.rkt b/malted/I-adam.rkt index 818e300..bb72940 100644 --- a/malted/I-adam.rkt +++ b/malted/I-adam.rkt @@ -19,7 +19,7 @@ (let ((r (smooth beta (ref pa 2) (sqr g)))) (let ((alpha-hat (/ alpha (+ (sqrt r) epsilon))) (v (smooth mu (ref pa 1) g))) - (list (- (ref pa 0) (* alpha-hat v)) v r))))) + (list (- (ref pa 0) (* alpha-hat v)) v r))))) (define adam-gradient-descent (gradient-descent From 38fe2fd2e110a7533af1e5644fb7d49ee8f7008e Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 13 Jul 2024 09:23:44 -0400 Subject: [PATCH 16/83] [fix-malted]Use check-within for malted tests --- malted/test/test-D-gradient-descent.rkt | 5 +++-- malted/test/test-F-naked.rkt | 5 +++-- malted/test/test-G-velocity.rkt | 5 +++-- malted/test/test-H-rms.rkt | 6 +++--- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/malted/test/test-D-gradient-descent.rkt b/malted/test/test-D-gradient-descent.rkt index b21193e..554cb57 100644 --- a/malted/test/test-D-gradient-descent.rkt +++ b/malted/test/test-D-gradient-descent.rkt @@ -12,8 +12,9 @@ (define naked-gd (gradient-descent id id (λ (theta g) (--ρ theta (*-ρ g alpha))))) - (check-equal? + (check-within (with-hypers ((revs 500) (alpha 0.01)) (naked-gd obj (list 3.0))) - '(29.998892352401082))) + '(29.998892352401082) + (tolerance))) diff --git a/malted/test/test-F-naked.rkt b/malted/test/test-F-naked.rkt index 2f0ebf7..d59ad59 100644 --- a/malted/test/test-F-naked.rkt +++ b/malted/test/test-F-naked.rkt @@ -5,8 +5,9 @@ (define obj (λ (theta) (sqr (- 30 (ref theta 0))))) - (check-equal? + (check-within (with-hypers ((revs 400) (alpha 0.01)) (naked-gradient-descent obj (list 3.0))) - '(29.991647931623252))) + '(29.991647931623252) + (tolerance))) diff --git a/malted/test/test-G-velocity.rkt b/malted/test/test-G-velocity.rkt index 8502003..154de9d 100644 --- a/malted/test/test-G-velocity.rkt +++ b/malted/test/test-G-velocity.rkt @@ -5,9 +5,10 @@ (define obj (λ (theta) (sqr (- 30 (ref theta 0))))) - (check-dual-equal? + (check-within (with-hypers ((revs 70) (alpha 0.01) (mu 0.9)) (velocity-gradient-descent obj (list 3.0))) - '(30.686162582787535))) + '(30.686162582787535) + (tolerance))) diff --git a/malted/test/test-H-rms.rkt b/malted/test/test-H-rms.rkt index 750352b..cb7e84b 100644 --- a/malted/test/test-H-rms.rkt +++ b/malted/test/test-H-rms.rkt @@ -5,10 +5,10 @@ (define obj (λ (theta) (sqr (- 30 (ref theta 0))))) - (check-dual-equal? + (check-within (with-hypers ((revs 170) (alpha 0.1) (beta 0.999)) (rms-gradient-descent obj (list 3.0))) - '(29.990436450964964)) - ) + '(29.990436450964964) + (tolerance))) From 4f72470639b771a5c02a379f079270603d6caf5f Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 31 Jan 2024 07:49:03 -0500 Subject: [PATCH 17/83] [add-unif]Add uniform-tensors representation using f32vector --- Makefile | 49 +++ impl-no-duals.rkt | 2 + impl-no-overrides.rkt | 2 + impl.rkt | 2 + set-impl.rkt | 5 +- uniform-tensors.rkt | 45 ++ uniform-tensors/autodiff.rkt | 23 ++ uniform-tensors/autodiff/A-autodiff.rkt | 120 ++++++ uniform-tensors/autodiff/B-prims.rkt | 221 ++++++++++ .../autodiff/C-dualized-tensor-ops.rkt | 51 +++ uniform-tensors/autodiff/D-test-helpers.rkt | 47 +++ uniform-tensors/autodiff/E-print.rkt | 87 ++++ .../autodiff/test/test-A-autodiff.rkt | 15 + .../autodiff/test/test-E-print.rkt | 71 ++++ uniform-tensors/ext-impl.rkt | 28 ++ uniform-tensors/ext-ops.rkt | 40 ++ uniform-tensors/ext-ops/A-scalar-ops.rkt | 135 ++++++ uniform-tensors/ext-ops/B-comparators.rkt | 85 ++++ uniform-tensors/ext-ops/C-star-2-1.rkt | 44 ++ uniform-tensors/ext-ops/D-sum.rkt | 68 +++ uniform-tensors/ext-ops/E-argmax.rkt | 41 ++ uniform-tensors/ext-ops/F-max.rkt | 49 +++ uniform-tensors/ext-ops/G-correlate.rkt | 95 +++++ uniform-tensors/ext-ops/I-flatten.rkt | 31 ++ uniform-tensors/ext-ops/K-concat.rkt | 76 ++++ .../ext-ops/test/test-A-scalar-ops.rkt | 122 ++++++ .../ext-ops/test/test-B-comparators.rkt | 12 + .../ext-ops/test/test-C-star-2-1.rkt | 24 ++ uniform-tensors/ext-ops/test/test-D-sum.rkt | 58 +++ .../ext-ops/test/test-E-argmax.rkt | 17 + uniform-tensors/ext-ops/test/test-F-max.rkt | 10 + .../ext-ops/test/test-G-correlate.rkt | 118 ++++++ .../ext-ops/test/test-I-flatten.rkt | 13 + .../ext-ops/test/test-K-concat.rkt | 126 ++++++ uniform-tensors/no-duals-no-overrides.rkt | 29 ++ uniform-tensors/no-duals.rkt | 29 ++ uniform-tensors/no-overrides.rkt | 43 ++ uniform-tensors/tensors.rkt | 22 + uniform-tensors/tensors/0-vectors.rkt | 85 ++++ uniform-tensors/tensors/1-flats.rkt | 73 ++++ uniform-tensors/tensors/A-equality.rkt | 75 ++++ uniform-tensors/tensors/B-tensor-basics.rkt | 184 +++++++++ uniform-tensors/tensors/C-tensor-ops.rkt | 35 ++ uniform-tensors/tensors/D-extend.rkt | 391 ++++++++++++++++++ .../tensors/test/test-A-equality.rkt | 84 ++++ .../tensors/test/test-B-tensor-basics.rkt | 33 ++ .../tensors/test/test-C-tensor-ops.rkt | 63 +++ .../tensors/test/test-D-extend.rkt | 234 +++++++++++ 48 files changed, 3311 insertions(+), 1 deletion(-) create mode 100644 uniform-tensors.rkt create mode 100644 uniform-tensors/autodiff.rkt create mode 100644 uniform-tensors/autodiff/A-autodiff.rkt create mode 100644 uniform-tensors/autodiff/B-prims.rkt create mode 100644 uniform-tensors/autodiff/C-dualized-tensor-ops.rkt create mode 100644 uniform-tensors/autodiff/D-test-helpers.rkt create mode 100644 uniform-tensors/autodiff/E-print.rkt create mode 100644 uniform-tensors/autodiff/test/test-A-autodiff.rkt create mode 100644 uniform-tensors/autodiff/test/test-E-print.rkt create mode 100644 uniform-tensors/ext-impl.rkt create mode 100644 uniform-tensors/ext-ops.rkt create mode 100644 uniform-tensors/ext-ops/A-scalar-ops.rkt create mode 100644 uniform-tensors/ext-ops/B-comparators.rkt create mode 100644 uniform-tensors/ext-ops/C-star-2-1.rkt create mode 100644 uniform-tensors/ext-ops/D-sum.rkt create mode 100644 uniform-tensors/ext-ops/E-argmax.rkt create mode 100644 uniform-tensors/ext-ops/F-max.rkt create mode 100644 uniform-tensors/ext-ops/G-correlate.rkt create mode 100644 uniform-tensors/ext-ops/I-flatten.rkt create mode 100644 uniform-tensors/ext-ops/K-concat.rkt create mode 100644 uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt create mode 100644 uniform-tensors/ext-ops/test/test-B-comparators.rkt create mode 100644 uniform-tensors/ext-ops/test/test-C-star-2-1.rkt create mode 100644 uniform-tensors/ext-ops/test/test-D-sum.rkt create mode 100644 uniform-tensors/ext-ops/test/test-E-argmax.rkt create mode 100644 uniform-tensors/ext-ops/test/test-F-max.rkt create mode 100644 uniform-tensors/ext-ops/test/test-G-correlate.rkt create mode 100644 uniform-tensors/ext-ops/test/test-I-flatten.rkt create mode 100644 uniform-tensors/ext-ops/test/test-K-concat.rkt create mode 100644 uniform-tensors/no-duals-no-overrides.rkt create mode 100644 uniform-tensors/no-duals.rkt create mode 100644 uniform-tensors/no-overrides.rkt create mode 100644 uniform-tensors/tensors.rkt create mode 100644 uniform-tensors/tensors/0-vectors.rkt create mode 100644 uniform-tensors/tensors/1-flats.rkt create mode 100644 uniform-tensors/tensors/A-equality.rkt create mode 100644 uniform-tensors/tensors/B-tensor-basics.rkt create mode 100644 uniform-tensors/tensors/C-tensor-ops.rkt create mode 100644 uniform-tensors/tensors/D-extend.rkt create mode 100644 uniform-tensors/tensors/test/test-A-equality.rkt create mode 100644 uniform-tensors/tensors/test/test-B-tensor-basics.rkt create mode 100644 uniform-tensors/tensors/test/test-C-tensor-ops.rkt create mode 100644 uniform-tensors/tensors/test/test-D-extend.rkt diff --git a/Makefile b/Makefile index 7ceafaa..fb1aade 100644 --- a/Makefile +++ b/Makefile @@ -12,6 +12,7 @@ TEST_FLAGS=-q LEARNER_DIR=learner FLAT_DIR=flat-tensors +UNIFORM_DIR=uniform-tensors NESTED_DIR=nested-tensors TOOLS_DIR=tools MALTED_DIR=malted @@ -109,6 +110,54 @@ FLAT_SOURCES=$(FLAT_TENSORS_SOURCES)\ $(FLAT_EXT_OPS_SOURCES)\ $(FLAT_LOADERS) +# uniform-tensors +UNIFORM_TENSORS_DIR=$(UNIFORM_DIR)/tensors +UNIFORM_AUTODIFF_DIR=$(UNIFORM_DIR)/autodiff +UNIFORM_EXT_OPS_DIR=$(UNIFORM_DIR)/ext-ops + +UNIFORM_TENSORS_SOURCES=\ + $(UNIFORM_TENSORS_DIR)/0-vectors.rkt\ + $(UNIFORM_TENSORS_DIR)/1-flats.rkt\ + $(UNIFORM_TENSORS_DIR)/A-equality.rkt\ + $(UNIFORM_TENSORS_DIR)/B-tensor-basics.rkt\ + $(UNIFORM_TENSORS_DIR)/C-tensor-ops.rkt\ + $(UNIFORM_TENSORS_DIR)/D-extend.rkt\ + $(UNIFORM_DIR)/tensors.rkt + +UNIFORM_AUTODIFF_SOURCES=\ + $(UNIFORM_AUTODIFF_DIR)/A-autodiff.rkt\ + $(UNIFORM_AUTODIFF_DIR)/B-prims.rkt\ + $(UNIFORM_AUTODIFF_DIR)/C-dualized-tensor-ops.rkt\ + $(UNIFORM_AUTODIFF_DIR)/D-test-helpers.rkt\ + $(UNIFORM_AUTODIFF_DIR)/E-print.rkt\ + $(UNIFORM_DIR)/autodiff.rkt + +UNIFORM_EXT_OPS_SOURCES=\ + $(UNIFORM_EXT_OPS_DIR)/A-scalar-ops.rkt\ + $(UNIFORM_EXT_OPS_DIR)/B-comparators.rkt\ + $(UNIFORM_EXT_OPS_DIR)/C-star-2-1.rkt\ + $(UNIFORM_EXT_OPS_DIR)/D-sum.rkt\ + $(UNIFORM_EXT_OPS_DIR)/E-argmax.rkt\ + $(UNIFORM_EXT_OPS_DIR)/F-max.rkt\ + $(UNIFORM_EXT_OPS_DIR)/G-correlate.rkt\ + $(UNIFORM_EXT_OPS_DIR)/I-flatten.rkt\ + $(UNIFORM_EXT_OPS_DIR)/K-concat.rkt\ + $(UNIFORM_DIR)/ext-ops.rkt + +UNIFORM_LOADERS=\ + $(UNIFORM_DIR)/no-duals-no-overrides.rkt\ + $(UNIFORM_DIR)/no-duals.rkt\ + $(UNIFORM_DIR)/no-overrides.rkt\ + $(UNIFORM_DIR)/tensors.rkt\ + $(UNIFORM_DIR)/autodiff.rkt\ + $(UNIFORM_DIR)/ext-impl.rkt\ + $(UNIFORM_DIR)/ext-ops.rkt + +UNIFORM_SOURCES=$(UNIFORM_TENSORS_SOURCES)\ + $(UNIFORM_AUTODIFF_SOURCES)\ + $(UNIFORM_EXT_OPS_SOURCES)\ + $(UNIFORM_LOADERS) + # nested-tensors NESTED_TENSORS_DIR=$(NESTED_DIR)/tensors NESTED_AUTODIFF_DIR=$(NESTED_DIR)/autodiff diff --git a/impl-no-duals.rkt b/impl-no-duals.rkt index 05043ba..5384fca 100644 --- a/impl-no-duals.rkt +++ b/impl-no-duals.rkt @@ -13,10 +13,12 @@ #,(case (tensor-implementation) ((learner) #'(require "learner/no-duals.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals.rkt")) + ((uniform-tensors) #'(require "uniform-tensors/no-duals.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner/no-duals.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-duals.rkt"))))))) (load-tensors) diff --git a/impl-no-overrides.rkt b/impl-no-overrides.rkt index 0f479e8..56a96c7 100644 --- a/impl-no-overrides.rkt +++ b/impl-no-overrides.rkt @@ -13,10 +13,12 @@ #,(case (tensor-implementation) ((learner) #'(require "learner/no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-overrides.rkt")) + ((uniform-tensors) #'(require "uniform-tensors/no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-overrides.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner/no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-overrides.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-overrides.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-overrides.rkt"))))))) (load-tensors) diff --git a/impl.rkt b/impl.rkt index 4f09203..4508b1f 100644 --- a/impl.rkt +++ b/impl.rkt @@ -13,10 +13,12 @@ #,(case (tensor-implementation) ((learner) #'(require "learner.rkt")) ((flat-tensors) #'(require "flat-tensors.rkt")) + ((uniform-tensors) #'(require "uniform-tensors.rkt")) ((nested-tensors) #'(require "nested-tensors.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors.rkt"))))))) (load-tensors) diff --git a/set-impl.rkt b/set-impl.rkt index 20f67b5..712c23d 100644 --- a/set-impl.rkt +++ b/set-impl.rkt @@ -7,7 +7,10 @@ (define set-impl (λ (impl) - (when (not (member impl '(learner nested-tensors flat-tensors))) + (when (not (member impl '(learner + nested-tensors + flat-tensors + uniform-tensors))) (error "Unknown implementation: ~a~%" impl)) (setup #:collections (list (list "malt")) #:clean? #t) (write-implementation-to-config-file impl) diff --git a/uniform-tensors.rkt b/uniform-tensors.rkt new file mode 100644 index 0000000..e8479a6 --- /dev/null +++ b/uniform-tensors.rkt @@ -0,0 +1,45 @@ +#lang racket/base + +(require + (except-in "uniform-tensors/tensors.rkt" + rank shape reshape tref trefs tensor? tlen ref refr)) + +(require "uniform-tensors/autodiff.rkt") +(require "uniform-tensors/ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + max-tensor-print-length make-printable + + (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) + (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) + (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) + (d-flatten flatten) + (d-concat concat) (d-concat-n concat-n)) + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + sum-1 argmax-1 max-1 flatten-2 concat-1-1 + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/autodiff.rkt b/uniform-tensors/autodiff.rkt new file mode 100644 index 0000000..68168c9 --- /dev/null +++ b/uniform-tensors/autodiff.rkt @@ -0,0 +1,23 @@ +#lang racket + +(require "autodiff/A-autodiff.rkt") +(require "autodiff/B-prims.rkt") +(require "autodiff/C-dualized-tensor-ops.rkt") +(require "autodiff/D-test-helpers.rkt") +(require "autodiff/E-print.rkt") + +(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual* map*) +(provide prim1 prim2 ext1 ext2) +(provide (rename-out (d-rank rank) + (d-shape shape) + (d-reshape reshape) + (d-tref tref) + (d-trefs trefs) + (d-tensor? tensor?) + (d-tlen tlen) + (d-ref ref) + (d-refr refr))) + +(provide check-dual-equal? check-ρ-∇) + +(provide max-tensor-print-length make-printable) diff --git a/uniform-tensors/autodiff/A-autodiff.rkt b/uniform-tensors/autodiff/A-autodiff.rkt new file mode 100644 index 0000000..03b9c93 --- /dev/null +++ b/uniform-tensors/autodiff/A-autodiff.rkt @@ -0,0 +1,120 @@ +#lang racket + +(require "../tensors.rkt") + +;;---------------------------- +;; Real part of a dual is always a tensor (of any rank) +;;---------------------------- + +(define dual? + (λ (x) + (and (vector? x) (eq? (vector-ref x 0) dual)))) + +(define dual + (λ (r k) + (vector dual r k))) + +(define dual* + (λ (d) + (dual (ρ d) end-of-chain))) + +(define ρ + (λ (d) + (cond + ((dual? d) (vector-ref d 1)) + (else d)))) + +(define κ + (λ (d) + (cond + ((dual? d) (vector-ref d 2)) + (else end-of-chain)))) + +(define scalar? + (λ (d) + (or (number? d) + (and (dual? d) + (number? (ρ d)))))) + +(define dual-like? + (λ (d) + (or (dual? d) + (number? d) + (vector? d)))) + +;;---------------------------- +;; Chain rule +;;---------------------------- + +(define end-of-chain + (λ (d z σ) + (let ((g (hash-ref σ d 0.0))) + (hash-set σ d (+-ρ z g))))) + +(define +-ρ + (ext2-ρ + 0 0)) + +;;---------------------------- +;; Reverse-mode AD +;;---------------------------- + +(define ∇ + (λ (f theta) + (let ((wrt (map* dual* theta))) + (∇-once (f wrt) wrt)))) + +(define ∇¹ + (λ (f) + (λ xs + (let ((wrt (map* dual* xs))) + (∇-once (apply f wrt) wrt))))) + +(define ∇-once + (λ (y wrt) + (let ((σ (∇σ y (hasheq)))) + (map* (λ (d) + (hash-ref σ d 0.0)) + wrt)))) + +(define ∇σ + (λ (y σ) + (cond + ((dual-like? y) ((κ y) y (one-like (ρ y)) σ)) + ((list? y) (∇σ-list y σ)) + (else (printf "Unknown: ~a~%" y))))) + +(define ∇σ-list + (λ (y σ) + (cond + ((null? y) σ) + (else + (let ((σ-hat (∇σ (ref y 0) σ))) + (∇σ-list (refr y 1) σ-hat)))))) + +;;---------------------------- +;; General helpers +;;---------------------------- + +(define map* + (λ (f y) + (cond + ((dual-like? y) (f y)) + ((list? y) + (map (λ (yi) + (map* f yi)) + y)) + (else y)))) + +(define trace-print + (λ (v port) + (cond + ((dual? v) (trace-print (ρ v) port)) + (else (fprintf port "~a~%" v))))) + +(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) + +(include "test/test-A-autodiff.rkt") + +(provide + dual dual? ρ κ ∇ ∇¹ dual* scalar? end-of-chain map* + trace-print) diff --git a/uniform-tensors/autodiff/B-prims.rkt b/uniform-tensors/autodiff/B-prims.rkt new file mode 100644 index 0000000..e029cc9 --- /dev/null +++ b/uniform-tensors/autodiff/B-prims.rkt @@ -0,0 +1,221 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(define ρ-function + (λ (f) (f ρ-function))) + +(define ∇-function + (λ (f) (f ∇-function))) + +(define shape-fn + (λ (f) (f shape-fn))) + +;; For flat tensors, ρ-fn and ∇-fn +;; are of two types: functional and pre-allocated +;; When they are functional, they return values +;; When they are pre-allocated, they expect expect the +;; return flat-store to be pre-allocated, and simply +;; operate as fillers. +;; +;; Pre-allocated ρ and ∇ have arities +;; 6 and 7 for unary ops, and 9 and 10 for binary ops. +;; We test for this arity to determine the type. +;; +;; Generally speaking, scalar operations are functional +;; and vector operations are pre-allocated. +;; +;; The functions ensure-ρ-callable-1, ensure-∇-callable-1 +;; and ensure-ρ-callable-2, ensure-∇-callable-2 provide +;; the preallocation for flat-stores when a vector-op is +;; provided, but the invocation of prim1 expects functional +;; results. +;; + +(define prim1 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) + (∇-callable (ensure-∇-callable-1 ∇-fn shape))) + (λ (daf) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim1-dual ρ-callable ∇-callable daf))))))) + +(define prim1-dual + (λ (ρ-fn ∇-fn da) + (let ((ra (ρ da))) + (dual (ρ-fn ra) + (λ (d z σ) + (let ((ga (∇-fn ra z))) + ((κ da) da ga σ))))))) + +(define prim2 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) + (∇-callable (ensure-∇-callable-2 ∇-fn shape))) + (λ ds + (let ((daf (ref ds 0))) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1))))))))) + +(define prim2-dual + (λ (ρ-fn ∇-fn da db) + (let ((ra (ρ da)) + (rb (ρ db))) + (dual (ρ-fn ra rb) + (λ (d z σ) + (let-values (((ga gb) (∇-fn ra rb z))) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat)))))))) + +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define ensure-ρ-callable-1 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra) + (apply-flat-ρ-fn-1 ρ-fn ra shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-1 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra z) + (apply-flat-∇-fn-1 ∇-fn ra z shape-fn))) + (else ∇-fn)))) + +(define ensure-ρ-callable-2 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra rb) + (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-2 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra rb z) + (apply-flat-∇-fn-1 ∇-fn ra rb z shape-fn))) + (else ∇-fn)))) + +(define apply-flat-ρ-fn-1 + (λ (ρ-fn ra shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-1 + (λ (∇-fn ra z shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (let ((g (new-vec in-size 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g (flat-store ra) (flat-offset ra) in-size + v-z 0 1) + (flat in-shape g 0))) + (else + (∇-fn g (flat-store ra) (flat-offset ra) in-size + (flat-store z) (flat-offset z) out-size) + (flat in-shape g 0))))))) + +(define apply-flat-ρ-fn-2 + (λ (ρ-fn ra rb shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape ra)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-2 + (λ (∇-fn ra rb z shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape ra)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (let ((g0 (new-vec in-size-a 0.0)) + (g1 (new-vec in-size-b 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-z 0 1) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))) + (else + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + (flat-store z) (flat-offset z) out-size) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))))))) + +;;---------------------------- +;; Dualized tensor op creators +;;---------------------------- +(define ext1 + (λ (f n) + (prim1 + (ext1-ρ (ρ-function f) n (shape-fn f)) + (ext1-∇ (∇-function f) n (shape-fn f)) + (shape-fn f)))) + +(define ext2 + (λ (f m n) + (prim2 + (ext2-ρ (ρ-function f) m n (shape-fn f)) + (ext2-∇ (∇-function f) m n (shape-fn f)) + (shape-fn f)))) + +(provide prim1 prim2 ext1 ext2) diff --git a/uniform-tensors/autodiff/C-dualized-tensor-ops.rkt b/uniform-tensors/autodiff/C-dualized-tensor-ops.rkt new file mode 100644 index 0000000..5b02ba3 --- /dev/null +++ b/uniform-tensors/autodiff/C-dualized-tensor-ops.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + + +;;---------------------------- +;; Tensor ops, cleaned up. +;;---------------------------- + +(define d-rank + (lambda (t) + (rank (ρ t)))) + +(define d-shape + (λ (t) + (shape (ρ t)))) + +(define d-reshape + (λ (s t) + (cond + ((dual? t) + (dual (reshape s (ρ t)) + (κ t))) + (else (reshape s t))))) + +(define d-trefs + (λ (t b) + (trefs (ρ t) b))) + +(define d-tref + (λ (t i) + (tref (ρ t) i))) + +(define d-tensor? + (λ (t) + (tensor? (ρ t)))) + +(define d-tlen + (λ (t) + (tlen (ρ t)))) + +(define d-ref + (λ (l i) + (ref l (ρ i)))) + +(define d-refr + (λ (l i) + (refr l (ρ i)))) + +(provide d-rank d-shape d-reshape d-trefs d-tensor? d-tlen d-ref d-refr d-tref) diff --git a/uniform-tensors/autodiff/D-test-helpers.rkt b/uniform-tensors/autodiff/D-test-helpers.rkt new file mode 100644 index 0000000..5b21797 --- /dev/null +++ b/uniform-tensors/autodiff/D-test-helpers.rkt @@ -0,0 +1,47 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(require rackunit) + +(define-binary-check (check-dual-equal? equal-wt? actual expected)) +(define-check (ρ-∇-checker fn args ans grads) + (let* ((y (apply fn args)) + (g (apply (∇¹ fn) args))) + (cond + ((and (equal-wt? ans (ρ y)) + (equal-wt? grads (ρ g))) (void)) + ((equal-wt? ans (ρ y)) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" + (ρ g) grads))) + (else + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" + (ρ y) ans)))))) + +(define-syntax check-ρ-∇ + (syntax-rules () + [(check-both (fn args ...) ans grads) + (ρ-∇-checker fn (list args ...) ans grads)])) + +(define equal-wt? + (λ (a b) + (cond + ((and (tensor? a) (tensor? b)) + (tensor-equal? a b)) + ((dual? a) (equal-wt? (ρ a) b)) + ((dual? b) (equal-wt? a (ρ b))) + ((and (vector? a) (vector? b) + (= (vector-length a) (vector-length b))) + (vector-andmap equal-wt? a b)) + ((and (pair? a) (pair? b) + (= (length a) (length b))) + (andmap equal-wt? a b)) + (else (equal? a b))))) + +(define vector-andmap + (λ (f v1 v2) + (for/fold ([s #t]) ([v1 v1][v2 v2]) + (and s (f v1 v2))))) + +(provide check-dual-equal? check-ρ-∇) diff --git a/uniform-tensors/autodiff/E-print.rkt b/uniform-tensors/autodiff/E-print.rkt new file mode 100644 index 0000000..7f4090e --- /dev/null +++ b/uniform-tensors/autodiff/E-print.rkt @@ -0,0 +1,87 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "A-autodiff.rkt") +(require "../tensors.rkt") + +(define max-tensor-print-length (make-parameter 5)) + +(struct fake-tensor (members) + #:transparent + #:methods gen:custom-write + ((define write-proc + (λ (fake-tensor port mode) + (let ((n (length (fake-tensor-members fake-tensor)))) + (case mode + ((#t) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (write m port)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + ((#f) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (display m port) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + (else + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (print m port mode)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)))))))) + +(define make-printable + (λ (y [max-length (max-tensor-print-length)]) + (cond + ((dual? y) (make-printable (ρ y))) + ((flat? y) (make-printable-flat y max-length)) + ((list? y) + (map (λ (le) (make-printable le max-length)) y)) + ((vector? y) + (vector-map (λ (ve) (make-printable ve max-length)) y)) + (else y)))) + +(define make-printable-flat + (λ (y max-length) + (flat->tensor-list + (flat-store y) (flat-offset y) (flat-shape y) + (strides (flat-shape y)) max-length))) + +(define flat->tensor-list + (λ (store offset shape strides max-length) + (cond + ((null? shape) (vref store offset)) + (else + (let ((top-len (car shape)) + (stride (car strides))) + (fake-tensor + (reverse + (call/cc + (λ (return) + (for/fold ((lst '())) ((i (in-range offset (+ offset (* top-len stride)) stride)) + (count (in-naturals 0))) + (cond + ((and (> max-length 0) (= count max-length)) (return (cons '... lst))) + (else + (cons (flat->tensor-list store i (cdr shape) (cdr strides) max-length) + lst))))))))))))) + +(include "test/test-E-print.rkt") + +(provide max-tensor-print-length + make-printable + ;; This is used in ext-impl.rkt + make-printable-flat + fake-tensor) diff --git a/uniform-tensors/autodiff/test/test-A-autodiff.rkt b/uniform-tensors/autodiff/test/test-A-autodiff.rkt new file mode 100644 index 0000000..b453b3e --- /dev/null +++ b/uniform-tensors/autodiff/test/test-A-autodiff.rkt @@ -0,0 +1,15 @@ +(module+ test + (require rackunit) + (let ((k0 end-of-chain)) + (let ((dual0 0) + (dual1 (dual 1 k0))) + + (check-equal? dual1 (dual 1 k0)) + (check-true (dual? dual1)) + (check-false (dual? 1)) + (check-equal? (ρ dual1) 1) + (check-equal? (ρ dual0) 0) + (check-equal? (κ dual1) k0) + + (check-equal? (map* (λ (d) (ρ d)) (∇-once dual1 (list dual0 dual1))) + '(0.0 1.0))))) diff --git a/uniform-tensors/autodiff/test/test-E-print.rkt b/uniform-tensors/autodiff/test/test-E-print.rkt new file mode 100644 index 0000000..08c7c8a --- /dev/null +++ b/uniform-tensors/autodiff/test/test-E-print.rkt @@ -0,0 +1,71 @@ +(module+ test + (require rackunit) + (require "../tensors.rkt") + + (define long-tensor + (tensor 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15)) + + (define dualized-long-tensor + (dual long-tensor end-of-chain)) + + (define deep-tensor + (tensor long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor)) + + (define deeper-tensor + (tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) + + (check-equal? (make-printable-flat long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable-flat deep-tensor 3) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...))) + + (check-equal? (make-printable-flat deeper-tensor 3) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + '...))) + (parameterize ((max-tensor-print-length 3)) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) + (list + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + '...)))))) diff --git a/uniform-tensors/ext-impl.rkt b/uniform-tensors/ext-impl.rkt new file mode 100644 index 0000000..6eb7d34 --- /dev/null +++ b/uniform-tensors/ext-impl.rkt @@ -0,0 +1,28 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require (only-in "tensors/B-tensor-basics.rkt" + merge-flats)) +(require (only-in "tensors/D-extend.rkt" + merge-shapes + min-shape + ext2-shapes + flat-ext1-∇ + flat-ext1-ρ + flat-ext2-ρ + functional->preallocated-1-ρ + functional->preallocated-1-∇ + functional->preallocated-2-ρ + functional->preallocated-2-∇ + idxs + scalarize + ensure-flat)) +(require (only-in "autodiff/E-print.rkt" + make-printable-flat + fake-tensor)) + +(provide (all-from-out "tensors/0-vectors.rkt")) +(provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/B-tensor-basics.rkt")) +(provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/E-print.rkt")) diff --git a/uniform-tensors/ext-ops.rkt b/uniform-tensors/ext-ops.rkt new file mode 100644 index 0000000..83af7de --- /dev/null +++ b/uniform-tensors/ext-ops.rkt @@ -0,0 +1,40 @@ +#lang racket + +(require "ext-ops/A-scalar-ops.rkt") +(require "ext-ops/B-comparators.rkt") +(require "ext-ops/C-star-2-1.rkt") +(require "ext-ops/D-sum.rkt") +(require "ext-ops/E-argmax.rkt") +(require "ext-ops/F-max.rkt") +(require "ext-ops/G-correlate.rkt") +(require "ext-ops/I-flatten.rkt") +(require "ext-ops/K-concat.rkt") + +(provide d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ) + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) + +(provide d*-2-1 *-2-1-ρ) + +(provide sum-1 d-sum sum-ρ d-sum-cols sum-cols-ρ) + +(provide argmax-1 d-argmax argmax-ρ) + +(provide max-1 d-max max-ρ) + +(provide correlate-ρ d-correlate) + +(provide flatten-2 d-flatten flatten-ρ) + +(provide concat-1-1 d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/uniform-tensors/ext-ops/A-scalar-ops.rkt b/uniform-tensors/ext-ops/A-scalar-ops.rkt new file mode 100644 index 0000000..b251e80 --- /dev/null +++ b/uniform-tensors/ext-ops/A-scalar-ops.rkt @@ -0,0 +1,135 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) +(require "../autodiff.rkt") + +(define +-0-0 + (prim2 + + (λ (a b z) + (values z z)))) + +(define --0-0 + (prim2 - + (λ (a b z) + (values z (- z))))) + +(define *-0-0 + (prim2 * + (λ (a b z) + (values (* b z) (* a z))))) + +(define /-0-0 + (prim2 / + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))))) + +(define expt-0-0 + (prim2 expt + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))))) + +(define exp-0 + (prim1 exp + (λ (a z) + (* z (exp a))))) + +(define log-0 + (prim1 log + (λ (a z) + (* z (/ 1 a))))) + +(define sqrt-0 + (prim1 sqrt + (λ (x z) + (/ z (* 2 (sqrt x)))))) + +(define abs-0-ρ + (λ (x) + (cond + ((< x 0) (* -1 x)) + (else x)))) + +(define abs-0-∇ + (λ (x z) + (cond + ((< x 0) (- z)) + (else z)))) + +(define abs-0 + (prim1 abs-0-ρ abs-0-∇)) + +(define rectify-0-ρ + (λ (s) + (cond + ((< s 0.0) 0.0) + (else s)))) + +(define rectify-0-∇ + (λ (s z) + (cond + ((< s 0.0) 0.0) + (else z)))) + +(define rectify-shape + (λ (s) s)) + +(define rectify-0 + (prim1 rectify-0-ρ rectify-0-∇ rectify-shape)) + +;;------------------------------------ +;; differentiable extended functions. +;;------------------------------------ + +(define d* (ext2 *-0-0 0 0)) +(define d+ (ext2 +-0-0 0 0)) +(define d- (ext2 --0-0 0 0)) +(define d/ (ext2 /-0-0 0 0)) +(define d-expt (ext2 expt-0-0 0 0)) + +(define d-exp (ext1 exp-0 0)) +(define d-log (ext1 log-0 0)) +(define d-abs (ext1 abs-0 0)) +(define d-rectify (ext1 rectify-0 0)) +(define d-sqrt (ext1 sqrt-0 0)) + +(define d-sqr + (λ (x) + (d* x x))) + +;;------------------------------------ +;; non-differentiable extended functions. +;;------------------------------------ + +(define *-ρ (ext2-ρ * 0 0)) +(define +-ρ (ext2-ρ + 0 0)) +(define --ρ (ext2-ρ - 0 0)) +(define /-ρ (ext2-ρ / 0 0)) +(define expt-ρ (ext2-ρ expt 0 0)) + +(define exp-ρ (ext1-ρ exp 0)) +(define log-ρ (ext1-ρ log 0)) +(define abs-ρ (ext1-ρ abs-0-ρ 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) + +(define sqrt-ρ + (λ (a) + (expt-ρ a 1/2))) + +(define sqr-ρ + (λ (x) + (*-ρ x x))) + +(include "test/test-A-scalar-ops.rkt") + +(provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ) diff --git a/uniform-tensors/ext-ops/B-comparators.rkt b/uniform-tensors/ext-ops/B-comparators.rkt new file mode 100644 index 0000000..c42a2cf --- /dev/null +++ b/uniform-tensors/ext-ops/B-comparators.rkt @@ -0,0 +1,85 @@ +#lang racket + +(require "../autodiff.rkt") + +;;---------------------------- +;; Boolean comparators +;;---------------------------- + +(define comparator + (λ (f) + (λ (da db) + (f (ρ da) (ρ db))))) + +(define =-0-0 + (comparator =)) + +(define <-0-0 + (comparator <)) + +(define <=-0-0 + (comparator <=)) + +(define >-0-0 + (comparator >)) + +(define >=-0-0 + (comparator >)) + +;;---------------------------- +;; Tensorized comparators +;;---------------------------- + +(define comparator-ρ + (λ (f) + (λ (da db) + (cond + ((f (ρ da) (ρ db)) 1.0) + (else 0.0))))) + +(define comparator-∇ + (λ (f) + (λ (da db z) + (cond + ((f (ρ da) (ρ db)) (values z z)) + (else (values 0.0 0.0)))))) + +(define comparator-shape + (λ (f) + (λ (sa sb) + sa))) + +(define comparator-prim + (λ (f) + (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + +(define extended-comparator + (λ (f) + (ext2 (comparator-prim f) 0 0))) + +(define =-1 + (extended-comparator =)) + +(define <-1 + (extended-comparator <)) + +(define >-1 + (extended-comparator >)) + +(define <=-1 + (extended-comparator <=)) + +(define >=-1 + (extended-comparator >=)) + +(define != + (λ (a b) + (not (= a b)))) + +(define !=-1 + (extended-comparator !=)) + +(include "test/test-B-comparators.rkt") + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/ext-ops/C-star-2-1.rkt b/uniform-tensors/ext-ops/C-star-2-1.rkt new file mode 100644 index 0000000..5eb0d63 --- /dev/null +++ b/uniform-tensors/ext-ops/C-star-2-1.rkt @@ -0,0 +1,44 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ)) +(require "../autodiff.rkt") + +(define *-2-1-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (vset! v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (let ((a (vref v0 (+ i0 i))) + (b (vref v1 (+ i1 (modulo i stride1)))) + (z (vref vz (+ iz i)))) + (vset! g0 (+ i0 i) + (+ (vref g0 (+ i0 i)) (* z b))) + (vset! g1 (+ i1 (modulo i stride1)) + (+ (vref g1 (+ i1 (modulo i stride1))) (* z a))))))) + +(define *-2-1-shape + (λ (s t) + s)) + +(define *-2-1 + (prim2 *-2-1-base-ρ *-2-1-base-∇ *-2-1-shape)) + +(define d*-2-1 + (ext2 *-2-1 2 1)) + +(define *-2-1-ρ + (ext2-ρ *-2-1-base-ρ 2 1 *-2-1-shape)) + +(include "test/test-C-star-2-1.rkt") + +(provide *-2-1-ρ d*-2-1) diff --git a/uniform-tensors/ext-ops/D-sum.rkt b/uniform-tensors/ext-ops/D-sum.rkt new file mode 100644 index 0000000..44c6b8e --- /dev/null +++ b/uniform-tensors/ext-ops/D-sum.rkt @@ -0,0 +1,68 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define sum-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([sum 0.0]) ([i (in-range i0 (+ i0 stride0))]) + (+ sum (vref v0 i)))))) + +(define sum-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i + (+ (vref g0 i) z)))))) + +(define sum-shape + (λ (st) + (refr st 1))) + +(define sum-1 + (prim1 sum-1-ρ sum-1-∇ sum-shape)) + +(define d-sum + (ext1 sum-1 1)) + +(define sum-ρ + (ext1-ρ sum-1-ρ 1 sum-shape)) + +(provide d-sum sum-ρ) + +(define sum-cols-2-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (for ((i (in-range 0 stride-out))) + (vset! v-out (+ i i-out) + (for/fold ([sum 0.0]) ([j (in-range i0 (+ i0 stride0) stride-out)]) + (+ sum (vref v0 (+ j i)))))))) + +(define sum-cols-2-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (for ((i (in-range 0 stride-z))) + (for ([j (in-range i0 (+ i0 stride0) stride-z)]) + (vset! g0 (+ i j) + (+ (vref g0 (+ i j)) (vref vz (+ i iz)))))))) + +(define sum-cols-shape + (λ (s) + (refr s 1))) + +(define sum-cols-2 + (prim1 sum-cols-2-ρ sum-cols-2-∇ sum-cols-shape)) + +(define d-sum-cols + (ext1 sum-cols-2 2)) + +(define sum-cols-ρ + (ext1-ρ sum-cols-2-ρ 2 sum-cols-shape)) + +(include "test/test-D-sum.rkt") + +(provide sum-1 d-sum-cols sum-cols-ρ) diff --git a/uniform-tensors/ext-ops/E-argmax.rkt b/uniform-tensors/ext-ops/E-argmax.rkt new file mode 100644 index 0000000..d68966d --- /dev/null +++ b/uniform-tensors/ext-ops/E-argmax.rkt @@ -0,0 +1,41 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define argmax-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([max 0.0] + [max-i -1] #:result max-i) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) (values v (+ (- i i0) 0.0))) + (else (values max max-i)))))))) + +(define argmax-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i 0.0))))) + +(define argmax-shape + (λ (st) + '())) + +(define argmax-1 + (prim1 argmax-1-ρ argmax-1-∇ argmax-shape)) + +(define d-argmax + (ext1 argmax-1 1)) + +(define argmax-ρ + (ext1-ρ argmax-1-ρ 1 argmax-shape)) + +(include "test/test-E-argmax.rkt") + +(provide argmax-1 d-argmax argmax-ρ) diff --git a/uniform-tensors/ext-ops/F-max.rkt b/uniform-tensors/ext-ops/F-max.rkt new file mode 100644 index 0000000..2d5f5d6 --- /dev/null +++ b/uniform-tensors/ext-ops/F-max.rkt @@ -0,0 +1,49 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define max-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([max 0.0]) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) v) + (else max))))))) + +(define max-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for/fold ([max -inf.0] + [max-i -1] #:result + (for ([i (in-range i0 (+ i0 stride0))]) + (cond + ((= i (+ i0 max-i)) (vset! g0 i z)) + (else (vset! g0 i 0.0))))) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) (values v (- i i0))) + (else (values max max-i)))))))) + +(define max-shape + (λ (st) + (cdr st))) + +(define max-1 + (prim1 max-1-ρ max-1-∇ max-shape)) + +(define d-max + (ext1 max-1 1)) + +(define max-ρ + (ext1-ρ max-1-ρ 1 max-shape)) + +(include "test/test-F-max.rkt") + +(provide max-1 d-max max-ρ) diff --git a/uniform-tensors/ext-ops/G-correlate.rkt b/uniform-tensors/ext-ops/G-correlate.rkt new file mode 100644 index 0000000..9db2109 --- /dev/null +++ b/uniform-tensors/ext-ops/G-correlate.rkt @@ -0,0 +1,95 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ len)) +(require "../autodiff.rkt") + +;; Correlation is written taking into account how ext2 works +;; Ext2 is responsible for producing the i-out'th output from +;; v0[i0] and v1[i1], we take advantage of this. The shape constants +;; n b m d are pre-calculated the striding constants nd md and qd +;; are calculated. + +(define correlate-3-1-ρ + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (vset! v-out (+ i-out i) + (for/fold ([sum 0.0]) ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (cond + ((and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (+ sum (* a b)))) + (else sum)))))))))) + +(define correlate-3-1-∇ + (λ (nd md qd) + (λ (g0 g1 + v0 i0 bmd + v1 i1 d + vz iz b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (let ((z (vref vz (+ iz i)))) + (for ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (when (and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-shape + (λ (bmd nd) + (list (car bmd)))) + +(define correlate-3-1 + (λ (nd md qd) + (prim2 + (correlate-3-1-ρ nd md qd) + (correlate-3-1-∇ nd md qd) + correlate-shape))) + +(define d-correlate + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2 (correlate-3-1 nd md qd) 3 1) bank signal)))) + +(define correlate-ρ + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape) + bank signal)))) + +(define last + (λ (n s) + (refr s (- (len s) n)))) + +(include "test/test-G-correlate.rkt") + +(provide d-correlate correlate-ρ) diff --git a/uniform-tensors/ext-ops/I-flatten.rkt b/uniform-tensors/ext-ops/I-flatten.rkt new file mode 100644 index 0000000..bf24773 --- /dev/null +++ b/uniform-tensors/ext-ops/I-flatten.rkt @@ -0,0 +1,31 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) +(require (only-in "../autodiff.rkt" prim1 ext1)) + +(define flatten-2-ρ + (λ (t) + (reshape (flatten-shape (shape t)) t))) + +(define flatten-2-∇ + (λ (t z) + (reshape (shape t) z))) + +(define flatten-shape + (λ (s) + (let ((rows (ref s 0)) + (cols (ref s 1))) + (list (* rows cols))))) + +(define flatten-2 + (prim1 flatten-2-ρ flatten-2-∇ flatten-shape)) + +(define d-flatten + (ext1 flatten-2 2)) + +(define flatten-ρ + (ext1-ρ flatten-2-ρ 2)) + +(include "test/test-I-flatten.rkt") + +(provide flatten-2 d-flatten flatten-ρ) diff --git a/uniform-tensors/ext-ops/K-concat.rkt b/uniform-tensors/ext-ops/K-concat.rkt new file mode 100644 index 0000000..e8e3a99 --- /dev/null +++ b/uniform-tensors/ext-ops/K-concat.rkt @@ -0,0 +1,76 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (rename-in (only-in "../tensors.rkt" ext2-ρ tref tlen shape len ref) + (shape shape-ρ))) +(require (only-in "../autodiff.rkt" prim2 ext2 shape)) + +(define concat-shape + (λ (st su) + (cons (+ (ref st 0) (ref su 0)) + (cdr st)))) + +(define concat-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (cond + ((< i stride0) + (vector-set! v-out (+ i-out i) (vector-ref v0 (+ i0 i)))) + (else + (vector-set! v-out (+ i-out i) (vector-ref v1 (+ i1 (- i stride0))))))))) + +(define concat-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (cond + ((< i stride0) + (vset! g0 (+ i0 i) + (+ (vref g0 (+ i0 i)) + (vref vz (+ iz i))))) + (else + (vset! g1 (+ i1 (- i stride0)) + (+ (vref g1 (+ i1 (- i stride0))) + (vref vz (+ iz i))))))))) + +(define concat-base + (prim2 concat-base-ρ concat-base-∇ concat-shape)) + +(define d-concat-n + (λ (n) + (λ (t u) + (let ((st (shape t)) + (su (shape u))) + (ensure-compatible-shapes n st su) + ((ext2 concat-base n n) t u))))) + +(define concat-n-ρ + (λ (n) + (λ (t u) + (let ((st (shape-ρ t)) + (su (shape-ρ u))) + (ensure-compatible-shapes n st su) + ((ext2 concat-base-ρ n n concat-shape) t u))))) + +(define ensure-compatible-shapes + (λ (n st su) + (let ((rt (len st)) + (ru (len su))) + ;; The shape of the tensor of rank r at rank n-1 + ;; is given by (drop st (+ 1 (- r n))) + (when (not (equal? (drop st (+ 1 (- rt n))) (drop su (+ 1 (- ru n))))) + (error 'concat "Incompatible concat shapes: ~a and ~a at last ~a dimensions" + st su n))))) + +(define d-concat (d-concat-n 1)) +(define concat-ρ (concat-n-ρ 1)) +(define concat-1-1 concat-base) + +(include "test/test-K-concat.rkt") + +(provide concat-1-1 + d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt b/uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt new file mode 100644 index 0000000..2c13e39 --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-A-scalar-ops.rkt @@ -0,0 +1,122 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + ;; Check basic numericals + (let ((a 2) + (b 3)) + (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) + (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0)) + (check-ρ-∇ (d* a b) 6 (list 3.0 2.0)) + (check-ρ-∇ (d/ a b) + 2/3 + '(0.3333333333333333 -0.2222222222222222)) + (check-ρ-∇ (d-exp a) (exp 2) (list (exp 2))) + (check-ρ-∇ (d-log a) (log 2) (list 0.5)) + (check-ρ-∇ (d-expt a b) 8 + (list 12.0 5.545177444479562)) + (check-ρ-∇ (d-sqrt a) + (sqrt 2) + (list 0.3535533905932738)) + + (let* ((first-derivative ((∇¹ d-sqr) b))) + (check-dual-equal? first-derivative (list 6.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* 2)) + (d (dual* 2))) + (check-dual-equal? + ((∇¹ z) c d) + (list 2.5 2.0))) + + (check-dual-equal? ((∇¹ z) a a) (list 2.5 2.0))) + + ;; Check numericals with vector-duals + + (let ((a (tensor 2.0 3.0 4.0)) + (b (tensor 3.0 8.0 9.0))) + (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) + (check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) + (check-ρ-∇ (d/ a b) + (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) + (list (tensor 0.3333333333333333 0.125 0.1111111111111111) + (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) + (check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) + (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) + (check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) + (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) + (check-ρ-∇ (d-expt a b) + (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) + (list (tensor 12.0 17496.0 589824.0) + (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) + + (check-ρ-∇ (d-sqrt a) + (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) + (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) + + (check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* (tensor 2.0 3.0 4.0))) + (d (dual* (tensor 2.0 3.0 4.0)))) + (check-dual-equal? + ((∇¹ z) c d) + (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + (check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + ;; Check numericals with lists + (let ((x (dual* 3)) + (y (dual* 2))) + (let ((f d*)) + (check-ρ-∇ (d* x y) 6 '(2.0 3.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list (f m m) (f n n)))) + x y) + '(6.0 4.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list + (list (f m m) (f n n)) + (list (f m m) (f n n))))) + x y) + '(12.0 8.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (map (λ (m) (f m m)) + (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list n m)))) + x y) + '(4.0 6.0)))) + + (let ((a 7) + (b (tensor 13))) + (check-ρ-∇ (d+ a b) (tensor 20) (list 1.0 (tensor 1.0))) + (check-ρ-∇ (d* a b) (tensor 91) (list 13.0 (tensor 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13) (list 0.07692 (tensor -0.04142)))) + + (let ((a 7) + (b (tensor 13 15))) + (check-ρ-∇ (d+ a b) (tensor 20 22) (list 2.0 (tensor 1.0 1.0))) + (check-ρ-∇ (d* a b) (tensor 91 105) (list 28.0 (tensor 7.0 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13 7/15) + (list 0.14358 (tensor -0.04142 -0.03111))))) diff --git a/uniform-tensors/ext-ops/test/test-B-comparators.rkt b/uniform-tensors/ext-ops/test/test-B-comparators.rkt new file mode 100644 index 0000000..9f3fdf5 --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-B-comparators.rkt @@ -0,0 +1,12 @@ +(module+ test + (require rackunit) + (let ((a 2) + (b 3)) + (check-true (<-0-0 a b)) + (check-false (>-0-0 a b)) + (check-true (<=-0-0 a b)) + (check-false (>=-0-0 a b)) + (check-false (=-0-0 a b)) + (check-true (=-0-0 a a)) + (check-true (zero? 0)) + (check-false (zero? a)))) diff --git a/uniform-tensors/ext-ops/test/test-C-star-2-1.rkt b/uniform-tensors/ext-ops/test/test-C-star-2-1.rkt new file mode 100644 index 0000000..bbb2c8b --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-C-star-2-1.rkt @@ -0,0 +1,24 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (tensor (tensor 36 52 70 90) (tensor 84 104 126 150))) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/uniform-tensors/ext-ops/test/test-D-sum.rkt b/uniform-tensors/ext-ops/test/test-D-sum.rkt new file mode 100644 index 0000000..7b07d0d --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-D-sum.rkt @@ -0,0 +1,58 @@ +(module+ test + (require rackunit) + (require "C-star-2-1.ss") + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) + + (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d-sum a) 12 + (list (tensor 1.0 1.0 1.0)))) + + (let ((a (tensor (tensor 3 4 5) + (tensor 6 7 8)))) + (check-dual-equal? (d-sum a) (tensor 12 21)) + (check-dual-equal? ((∇¹ (λ (b) (d-sum (d* b b)))) a) + (list (tensor (tensor 6.0 8.0 10.0) + (tensor 12.0 14.0 16.0))))) + + (define dot-product + (λ (a b) + (d-sum (d*-2-1 a b)))) + + (define sse + (λ (a b) + (d-sum (d-sqr (d- a b))))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + + (check-ρ-∇ (sum-1 b) 14 + (list (tensor 1.0 1.0 1.0 1.0))) + + (check-ρ-∇ (dot-product a b) + (tensor 68 124) + (list (tensor (tensor 2.0 3.0 4.0 5.0) + (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + + (check-ρ-∇ (sse a b) + (tensor 4 100) + (list (tensor (tensor 2.0 2.0 2.0 2.0) + (tensor 10.0 10.0 10.0 10.0)) + (tensor -12.0 -12.0 -12.0 -12.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (dot-product a b) + (tensor (tensor 68 124) + (tensor 248 464)) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/uniform-tensors/ext-ops/test/test-E-argmax.rkt b/uniform-tensors/ext-ops/test/test-E-argmax.rkt new file mode 100644 index 0000000..f72819e --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-E-argmax.rkt @@ -0,0 +1,17 @@ +(module+ test + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-argmax y) (tensor 2.0 1.0 0.0 3.0) + (list + (tensor (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0)))))) diff --git a/uniform-tensors/ext-ops/test/test-F-max.rkt b/uniform-tensors/ext-ops/test/test-F-max.rkt new file mode 100644 index 0000000..01ab1a5 --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-F-max.rkt @@ -0,0 +1,10 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-max y) (tensor 1.0 1.0 1.0 1.0) + (list y)))) diff --git a/uniform-tensors/ext-ops/test/test-G-correlate.rkt b/uniform-tensors/ext-ops/test/test-G-correlate.rkt new file mode 100644 index 0000000..417723c --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-G-correlate.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor ext2-∇ check-tensor-equal?)) + + ;; for testing b = 4 + ;; m = 3 + ;; d = 2 + + ;; signal length n = 6 + + ;; (1 2) (3 4) (5 6) (7 8) (9 10) (11 12) + ;; (1 2) (3 4) (5 6) + ;; (7 8) (9 10) (11 12) + ;; (13 14) (15 16) (17 18) + ;; (19 20) (21 22) (23 24) + + ;; Signal is (n d) + (define signal (tensor (tensor 1 2) + (tensor 3 4) + (tensor 5 6) + (tensor 7 8) + (tensor 9 10) + (tensor 11 12))) + + (define bank (tensor (tensor + (tensor 1 2) + (tensor 3 4) + (tensor 5 6)) + (tensor + (tensor 7 8) + (tensor 9 10) + (tensor 11 12)) + (tensor + (tensor 13 14) + (tensor 15 16) + (tensor 17 18)) + (tensor + (tensor 19 20) + (tensor 21 22) + (tensor 23 24)))) + + (define corr-ρ + (ext2-ρ (correlate-3-1-ρ 12 6 2) 3 1 correlate-shape)) + + (define corr-∇ + (ext2-∇ (correlate-3-1-∇ 12 6 2) 3 1 correlate-shape)) + + (check-tensor-equal? (corr-ρ bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let-values (((filter-∇ signal-∇) + (corr-∇ bank signal (tensor (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0))))) + (check-tensor-equal? filter-∇ + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-tensor-equal? signal-∇ + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0)))) + + (check-dual-equal? (d-correlate bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let ((gs ((∇¹ d-correlate) bank signal))) + (check-dual-equal? (car gs) + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-dual-equal? (cadr gs) + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0))))) diff --git a/uniform-tensors/ext-ops/test/test-I-flatten.rkt b/uniform-tensors/ext-ops/test/test-I-flatten.rkt new file mode 100644 index 0000000..f7740cf --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-I-flatten.rkt @@ -0,0 +1,13 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0)) + + (check-dual-equal? (flatten-2 r2-t1) r1-t1) + (check-ρ-∇ ((λ (t1 t2) (d* t1 (flatten-2 t2))) r1-t1 r2-t1) + (tensor 9.0 16.0 25.0 36.0) + (list (tensor 3.0 4.0 5.0 6.0) (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))))) diff --git a/uniform-tensors/ext-ops/test/test-K-concat.rkt b/uniform-tensors/ext-ops/test/test-K-concat.rkt new file mode 100644 index 0000000..b427ced --- /dev/null +++ b/uniform-tensors/ext-ops/test/test-K-concat.rkt @@ -0,0 +1,126 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t2 (tensor 5.0 6.0 7.0)) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + + (check-dual-equal? + (d-concat r2-t1 r1-t2) + (tensor (tensor 3.0 4.0 5.0 6.0 7.0) + (tensor 5.0 6.0 5.0 6.0 7.0))) + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (d-concat t1 t2))) r2-t1 r1-t2 r1-t1) + (tensor (tensor 9.0 16.0 25.0 36.0 49.0) + (tensor 15.0 24.0 25.0 36.0 49.0)) + (list (tensor (tensor 3.0 4.0) (tensor 3.0 4.0)) + (tensor 10.0 12.0 14.0) + (tensor 8.0 10.0 10.0 12.0 14.0))) + (define r3-t1 + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0)))) + + + (define r2-t2 + (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0))) + + (define r1-t3 + (tensor 0.5 0.5)) + + (define concat-2 (d-concat-n 2)) + + (check-dual-equal? + (concat-2 r3-t1 r2-t2) + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)))) + + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (concat-2 t1 t2))) r3-t1 r2-t2 r1-t3) + (tensor (tensor (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 4.5 5.0) + (tensor 5.5 6.0) + (tensor 6.5 7.0) + (tensor 7.5 8.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 8.5 9.0) + (tensor 9.5 10.0) + (tensor 10.5 11.0) + (tensor 11.5 12.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0))) + (list + (tensor (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5))) + + (tensor (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5)) + + (tensor 192.0 216.0)))) diff --git a/uniform-tensors/no-duals-no-overrides.rkt b/uniform-tensors/no-duals-no-overrides.rkt new file mode 100644 index 0000000..ac07a7a --- /dev/null +++ b/uniform-tensors/no-duals-no-overrides.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + concat-ρ concat-n-ρ flatten-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/no-duals.rkt b/uniform-tensors/no-duals.rkt new file mode 100644 index 0000000..927c8c7 --- /dev/null +++ b/uniform-tensors/no-duals.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) + (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/no-overrides.rkt b/uniform-tensors/no-overrides.rkt new file mode 100644 index 0000000..35dcbdd --- /dev/null +++ b/uniform-tensors/no-overrides.rkt @@ -0,0 +1,43 @@ +#lang racket/base + +(require + (except-in "tensors.rkt" + rank shape reshape trefs tref tensor? tlen ref refr)) + +(require "autodiff.rkt") +(require "ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + make-printable + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + flatten-2 concat-1-1 + + d+ d- d* d/ d-rectify + d-exp d-log d-expt d-sqrt d-sqr + d-sum d-abs d*-2-1 d-argmax + d-max d-sum-cols d-correlate + d-flatten d-concat d-concat-n + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ concat-n-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/tensors.rkt b/uniform-tensors/tensors.rkt new file mode 100644 index 0000000..0daca26 --- /dev/null +++ b/uniform-tensors/tensors.rkt @@ -0,0 +1,22 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require "tensors/A-equality.rkt") +(require "tensors/B-tensor-basics.rkt") +(require "tensors/C-tensor-ops.rkt") +(require "tensors/D-extend.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide tolerance tensor-equal? check-tensor-equal?) + +(provide len ref refr) +(provide tref tlen list->tensor tensor build-tensor trefs) + +(provide ext1-ρ ext2-ρ ext1-∇ ext2-∇ expects-preallocated?) + +(provide flat flat? flat-shape flat-store flat-offset size-of strides) + +;; These will get overriden by duals +(provide tensor?) +(provide rank shape reshape size-of) diff --git a/uniform-tensors/tensors/0-vectors.rkt b/uniform-tensors/tensors/0-vectors.rkt new file mode 100644 index 0000000..2df0e8e --- /dev/null +++ b/uniform-tensors/tensors/0-vectors.rkt @@ -0,0 +1,85 @@ +#lang racket +(require ffi/vector) + +;;------------------------------------------------ +;; Raw representation of vectors +;;------------------------------------------------ + +(define vec? f32vector?) +(define vec f32vector) +(define make-vec make-f32vector) +(define vref f32vector-ref) +(define vset! f32vector-set!) +(define vlen f32vector-length) +(define list->vec list->f32vector) +(define build-vec + (λ (n proc) + (list->vec (map proc (range n))))) +(define vec->cpointer f32vector->cpointer) + +(define-for-syntax debug-leaks? #f) +(define-syntax when-debug-leaks + (λ (x) + (syntax-case x () + ((when-debug-leaks expr) + debug-leaks? + #'expr) + ((when-debug-leaks expr) + #'(void))))) + +(define new-vec + (λ (size initial-value [context 'new-vec]) + (let ((m (make-vec size initial-value))) + (when-debug-leaks (manage-flat-vector! m context)) + m))) + +(define vcopy + (λ (dest idest src isrc n) + (for ([id (in-range idest (+ n idest))] + [is (in-range isrc (+ n isrc))]) + (vset! dest id (vref src is))))) + +(provide vec? vec vref vset! vlen vcopy list->vec build-vec vec->cpointer new-vec) + +;;------------------------------------------------ +;; Memory management for flat-vectors +;;------------------------------------------------ + +(define flat-vector-manager + (make-will-executor)) + +(define manage-flat-vector! + (λ (m context) + (set-count! context (add1 (count context))) + (will-register flat-vector-manager m (flat-vector-collector context)))) + +(define flat-vector-collector + (λ (context) + (λ (v) + (cond + ((vector? v) + (set-count! context (sub1 (count context)))) + (else (fprintf (current-error-port) "?? ...")))))) + +(define start-vector-manager + (λ () + (when-debug-leaks + (void + (thread + (λ () + (let loop () + (will-execute flat-vector-manager) + (loop)))))))) + +(define counts (make-hash)) +(define count (λ (context) (dict-ref counts context 0))) +(define set-count! (λ (context v) (dict-set! counts context v))) +(define vector-manager-report + (λ () + (fprintf (current-error-port) "----------------------------------------------~%") + (fprintf (current-error-port) "context\t\t\tcount~%") + (for ([(context count) (in-hash counts)]) + (fprintf (current-error-port) "~a\t\t\t~a~%" context count)) + (fprintf (current-error-port) "----------------------------------------------~%"))) + +(provide start-vector-manager vector-manager-report) diff --git a/uniform-tensors/tensors/1-flats.rkt b/uniform-tensors/tensors/1-flats.rkt new file mode 100644 index 0000000..23aeb79 --- /dev/null +++ b/uniform-tensors/tensors/1-flats.rkt @@ -0,0 +1,73 @@ +#lang racket + +;-------------------------------------------------------- +; Representation of tensors +;-------------------------------------------------------- + +;; A flat tensor representation is for a contiguous slice in the backing store. +;; The fields we need: +;; shape : list +;; store : vector +;; offset : start of the contiguous slice. +;; size : number of elements in the contiguous slice +;; strides : Number of elements in each dimension of the tensor. +;; rank: Number of dimensions in the tensor + + + +(define flat + (λ (shape store offset) + (vector flat shape store offset + (size-of shape) + (strides shape) + (length shape)))) + +(define flat? + (λ (v) + (and (vector? v) + (eq? (vector-ref v 0) flat)))) + +(define flat-shape + (λ (f) + (vector-ref f 1))) + +(define flat-store + (λ (f) + (vector-ref f 2))) + +(define flat-offset + (λ (f) + (vector-ref f 3))) + +(define flat-size + (λ (f) + (vector-ref f 4))) + +(define flat-strides + (λ (f) + (vector-ref f 5))) + +(define flat-rank + (λ (f) + (vector-ref f 6))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(provide flat flat? flat-shape flat-store + flat-offset flat-rank flat-strides flat-size + size-of strides) diff --git a/uniform-tensors/tensors/A-equality.rkt b/uniform-tensors/tensors/A-equality.rkt new file mode 100644 index 0000000..0d2300a --- /dev/null +++ b/uniform-tensors/tensors/A-equality.rkt @@ -0,0 +1,75 @@ +#lang racket + +;;—————————————————–—————————————————–—————————————————– +;; Equality checks for mostly for testing. +;;—————————————————–—————————————————–—————————————————– + +(require "0-vectors.ss") +(require "1-flats.ss") +(require rackunit) + +;;—————————————————–—————————————————–—————————————————– +;; These parameters can be overriden to account for +;; different type of numbers used inside tensors. +;;—————————————————–—————————————————–—————————————————– + +(define tolerance (make-parameter 0.0001)) + +;;TODO: Discuss if this is the right fix, because without this 2 tests in +;;A-core.rkt fail. (Also delete the printfs later) +(tolerance 0.001) + +(define equal-within-tolerance? + (make-parameter + (λ (actual expected) + (< (abs (- actual expected)) (tolerance))))) + +;;—————————————————–—————————————————–—————————————————– +;; These are representation specific, but part of the +;; exported interface of the module +;;—————————————————–—————————————————–—————————————————– + +(define tensor-equal? + (λ (actual expected) + (or (equal? actual expected) + (and (real? actual) + (real? expected) + ((equal-within-tolerance?) actual expected)) + (and (flat? actual) + (flat? expected) + (equal? (flat-shape actual) + (flat-shape expected)) + (equal-elements? actual expected))))) + +(define (equal-elements? actual expected) + (let ((actual-offset (flat-offset actual)) + (expected-offset (flat-offset expected)) + (actual-size (flat-size actual)) + (expected-size (flat-size expected)) + (actual-store (flat-store actual)) + (expected-store (flat-store expected))) + ;(printf "###(equal-elements? ~a ~a)~n" actual expected) + (and (equal? actual-size expected-size) + (call/cc (λ (return) + (for/fold ([check #t]) + ([i-actual (in-range actual-offset + (+ actual-offset + actual-size))] + [i-expected (in-range expected-offset + (+ expected-offset + expected-size))]) + (define actual-elem (vref actual-store i-actual)) + (define expected-elem (vref expected-store i-expected)) + ;(printf "###actual: ~a \texpected: ~a~n" actual-elem expected-elem) + ;(printf "### |actual - expected| = ~a~n" (abs (- actual-elem expected-elem))) + (cond + (((equal-within-tolerance?) + actual-elem + expected-elem) check) + (else (return #f))))))))) + +(define-binary-check (check-tensor-equal? tensor-equal? actual expected)) + +(include "test/test-A-equality.rkt") + +(provide tolerance equal-within-tolerance? tensor-equal? check-tensor-equal?) diff --git a/uniform-tensors/tensors/B-tensor-basics.rkt b/uniform-tensors/tensors/B-tensor-basics.rkt new file mode 100644 index 0000000..63e9a71 --- /dev/null +++ b/uniform-tensors/tensors/B-tensor-basics.rkt @@ -0,0 +1,184 @@ +#lang racket + +;-------------------------------------------------------- +; Memory management tools for vectors +;-------------------------------------------------------- + +(require "0-vectors.ss") +(require "1-flats.ss") + +;-------------------------------------------------------- +; Lists +;-------------------------------------------------------- +(define ref list-ref) +(define refr drop) +(define len length) + +(provide ref refr len) + +;-------------------------------------------------------- +; Tensor basics +;-------------------------------------------------------- + +(define tref + (λ (t i) + (cond + ((= 1 (flat-rank t)) + (vref (flat-store t) (+ (flat-offset t) i))) + (else + (flat (cdr (flat-shape t)) + (flat-store t) + (+ (flat-offset t) (* i (car (flat-strides t))))))))) + +(define tlen + (λ (t) + (car (flat-shape t)))) + +(define flat-ref-idx + (λ (v indices) + (flat-ref-idx* (flat-offset v) (flat-strides v) indices))) + +(define flat-ref-idx* + (λ (current-idx strides indices) + (cond + ((null? indices) current-idx) + (else + (flat-ref-idx* + (+ current-idx + (* (car indices) (car strides))) + (cdr strides) + (cdr indices)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define list->tensor + (λ (lst) + (cond + ((null? lst) (error 'list->flat-tensor "No elements found")) + ((number? (car lst)) + (flat (list (length lst)) (list->vec lst) 0)) + (else + (flat-tensor-from-list lst))))) + +(define flat-tensor-from-list + (λ (lst) + (let* ([inner-shape (flat-shape (car lst))] + [inner-size (size-of inner-shape)] + [outer-shape (cons (length lst) inner-shape)] + [size (size-of outer-shape)] + [v (new-vec size 0.0 'from-list)]) + (for ([fl lst] + [i (in-naturals 0)]) + (vcopy v (* i inner-size) + (flat-store fl) (flat-offset fl) + inner-size)) + (flat outer-shape v 0)))) + +(define tensor? + (λ (t) + (or (number? t) + (flat? t)))) + +(define tensor + (λ args + (ensure-shape args) + (cond + ((number? (car args)) (flat (list (length args)) + (list->vec (map exact->inexact args)) + 0)) + (else (merge-flats args))))) + +(define merge-flats + (λ (args) + (let* ((inner-shape (flat-shape (car args))) + (outer (length args)) + + (new-shape (cons outer inner-shape)) + (stride (size-of inner-shape)) + + (new-size (size-of new-shape)) + + (v-out (new-vec new-size +nan.0 'tensor))) + (for ([i-out (in-range outer)] + [arg args]) + (vcopy v-out (* i-out stride) (flat-store arg) (flat-offset arg) stride)) + (flat new-shape v-out 0)))) + +(define ensure-shape + (λ (args) + (when (null? args) + (error 'tensor "Tensors cannot be empty")) + (let ((checked-shape + (λ (x) (if (flat? x) + (flat-shape x) + '())))) + (unless (and (not (null? args)) + (cond + ((number? (car args)) + (andmap number? (cdr args))) + ((flat? (car args)) + (let ((s (checked-shape (car args)))) + (andmap (λ (t) + (and (flat? t) + (equal? (checked-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Cannot construct a tensor out of these elements: ~a~%" + args))))) + +(define build-tensor + (λ (shape f) + (let* ((size (size-of shape)) + (v (new-vec size 0.0 'build-tensor)) + (strides (strides shape))) + (fill-flat-tensor v shape strides f 0 '()) + (flat shape v 0)))) + +(define fill-flat-tensor + (λ (dest shape strides f offset tidx) + (cond + ((null? (cdr shape)) + (for ([i (in-range 0 (car shape))]) + (vset! dest (+ offset i) + (exact->inexact (f (append tidx (list i))))))) + (else + (let ((stride (car strides))) + (for ([i (in-range 0 (car shape))]) + (fill-flat-tensor dest + (cdr shape) (cdr strides) f + (+ offset (* i stride)) (append tidx (list i))))))))) + +(define trefs + (λ (t b) + (let* ([st (flat-shape t)] + [est (cdr st)] + [estride (size-of est)] + [nshape (cons (length b) (cdr st))] + [size-out (size-of nshape)] + [v-out (new-vec size-out 0.0 'flat-refs)] + [vt (flat-store t)]) + (for ([ib b] + [i-out (in-range 0 size-out estride)]) + (vcopy v-out i-out vt (* ib estride) estride)) + (flat nshape v-out 0)))) + +(include "test/test-B-tensor-basics.rkt") + +(provide tref tlen list->tensor number? + tensor? tensor build-tensor trefs merge-flats) diff --git a/uniform-tensors/tensors/C-tensor-ops.rkt b/uniform-tensors/tensors/C-tensor-ops.rkt new file mode 100644 index 0000000..f9056b4 --- /dev/null +++ b/uniform-tensors/tensors/C-tensor-ops.rkt @@ -0,0 +1,35 @@ +#lang racket + +(require "1-flats.ss") +(require "B-tensor-basics.ss") + +;;—————————————————– +;; Shape, rank, size-of +;;—————————————————– + +(define shape + (λ (t) + (cond + ((number? t) '()) + (else (flat-shape t))))) + +(define rank + (λ (t) + (len (shape t)))) + +;;—————————————————– +;; Reshape a tensor +;;—————————————————– + +(define reshape + (λ (s t) + (cond + ((= (size-of s) (flat-size t)) + (flat s (flat-store t) (flat-offset t))) + (else (error "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + + +(include "test/test-C-tensor-ops.rkt") + +(provide rank shape reshape) +(provide size-of strides) diff --git a/uniform-tensors/tensors/D-extend.rkt b/uniform-tensors/tensors/D-extend.rkt new file mode 100644 index 0000000..07d955d --- /dev/null +++ b/uniform-tensors/tensors/D-extend.rkt @@ -0,0 +1,391 @@ +#lang racket + +(require "0-vectors.ss") +(require "1-flats.ss") +(require "B-tensor-basics.ss") +(require "C-tensor-ops.ss") + +;;—————————————————–—————————————————–—————————————————– +;; Unary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext1-ρ + (λ (f m [shape-fn scalar-shape]) + (λ (t) + (cond + ((number? t) (f t)) + ((expects-preallocated? f) + (scalarize + (flat-ext1-ρ f m shape-fn t))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f m shape-fn t)))))))) + +(define ext1-∇ + (λ (f m [shape-fn scalar-shape]) + (λ (t z) + (cond + ((number? t) (f t z)) + ((expects-preallocated? f) + (scalarize (flat-ext1-∇ f m shape-fn t (ensure-flat z)))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f m shape-fn t (ensure-flat z))))))))) + +(define functional->preallocated-1-ρ + (λ (f base-shape out-shape) + (λ (v0 i0 stride0 v-out i-out stride-out) + (set-prealloc-ρ! v-out i-out out-shape + (f (arg-value base-shape v0 i0)))))) + +(define functional->preallocated-1-∇ + (λ (f base-shape out-shape) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value base-shape v0 i0))) + (set-prealloc-∇! g0 i0 base-shape (f a z)))))) + +(define set-prealloc-ρ! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out a)) + (else (v-copy-flat! v-out i-out a))))) + +(define set-prealloc-∇! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out (+ (vref v-out i-out) a))) + (else (v-add-flat! v-out i-out a))))) + +(define arg-value + (λ (v-shape v i) + (cond + ((null? v-shape) (vref v i)) + (else (flat v-shape v i))))) + + +(define invoke-functional-∇ + (λ (f base-shape v0 i0) + (cond + ((null? base-shape) (f (vref v0 i0))) + (else (f (flat base-shape v0 i0)))))) + +;;—————————————————–—————————————————–—————————————————– +;; Binary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext2-ρ + (λ (f m n [shape-fn scalar-shape]) + (λ (t u) + (cond + ((and (number? t) (number? u)) (f t u)) + ((expects-preallocated? f) + (scalarize + (flat-ext2-ρ f m n shape-fn t u))) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn (ensure-flat t) u)))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t (ensure-flat u))))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t u)))))))) + +(define ext2-∇ + (λ (f m n [shape-fn scalar-shape]) + (λ (t u z) + (let ((invoke-flat-ext2-∇ + (λ (f m n shape-fn t u z) + (let-values (((da db) (flat-ext2-∇ f m n shape-fn t u z))) + (values (scalarize da) (scalarize db)))))) + (cond + ((and (number? t) (number? u)) (f t u z)) + ((expects-preallocated? f) + (invoke-flat-ext2-∇ f m n shape-fn t u z)) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn (ensure-flat t) u z))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn t (ensure-flat u) z))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn t u z)))))))) + +(define functional->preallocated-2-ρ + (λ (f t-shape u-shape out-shape) + (λ (v0 i0 stride0 v1 i1 stride1 v-out i-out stride-out) + (set-prealloc-ρ! v-out i-out out-shape + (f (arg-value t-shape v0 i0) + (arg-value u-shape v1 i1)))))) + +(define functional->preallocated-2-∇ + (λ (f t-shape u-shape out-shape) + (λ (g0 g1 v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value t-shape v0 i0)) + (b (arg-value u-shape v1 i1))) + (let-values (((da db) (f a b z))) + (set-prealloc-∇! g0 i0 t-shape da) + (set-prealloc-∇! g1 i1 u-shape db)))))) + +(define idxs + (λ (strides out-i i0 i1) + (for/fold ([i0 i0] + [i1 i1] + [x out-i] #:result (values i0 i1)) + ([stride strides]) + (let ((idx (quotient x (vector-ref stride 0))) + (next-x (remainder x (vector-ref stride 0)))) + (values (+ i0 (* idx (vector-ref stride 1))) + (+ i1 (* idx (vector-ref stride 2))) + next-x))))) + +(define merge-shapes + (λ (in-shape min-rank out-f-shape) + (append (take in-shape (- (length in-shape) min-rank)) + out-f-shape))) + +(define flat-ext1-ρ + (λ (f min-rank shape-fn t0) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sf-out (shape-fn sf0)) + (stride-out (size-of sf-out)) + (s-out (merge-shapes s0 min-rank sf-out)) + (size-out (size-of s-out)) + (v-out (new-vec size-out 0.0))) + (for ([i-out (in-range 0 size-out stride-out)] + [i0 (in-range off0 (+ off0 size0) stride0)]) + (f v0 i0 stride0 v-out i-out stride-out)) + (flat s-out v-out 0)))) + +(define flat-ext1-∇ + (λ (fᵈ min-rank shape-fn t0 z) + ;; z has the same shape as the output + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sz (flat-shape z)) + (size-z (size-of sz)) + (sf-z (shape-fn sf0)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + + (g0 (new-vec size0 0.0))) + (for ([iz (in-range 0 size-z stride-z)] + [i0 (in-range off0 (+ off0 size0) stride0)]) + (fᵈ g0 v0 i0 stride0 vz iz stride-z)) + (flat s0 g0 0)))) + +(define flat-ext2-ρ + (λ (f r0 r1 shape-fn t0 t1) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + + (sf-out (shape-fn sf0 sf1)) + (stride0 (size-of sf0)) + (stride1 (size-of sf1)) + (stride-out (size-of sf-out))) + (ext2-shapes s0 s1 r0 r1 sf-out + (λ (s-out size-out q0 q1 strides) + (let ((out-v (new-vec size-out 0.0))) + (for ([out-i (in-range 0 size-out stride-out)]) + (let-values (((i0 i1) + (idxs strides out-i off0 off1))) + (f v0 i0 stride0 v1 i1 stride1 out-v (+ 0 out-i) stride-out))) + (flat s-out out-v 0))))))) + +(define flat-ext2-∇ + (λ (fᵈ r0 r1 shape-fn t0 t1 z) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + (stride0 (size-of sf0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + (stride1 (size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + (offz (flat-offset z))) + (ext2-shapes s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (size-of s0) 0.0)) + (g1 (new-vec (size-of s1) 0.0))) + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) + (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))) + (values (flat s0 g0 0) + (flat s1 g1 0)))))))) + +(define ext2-shapes + (λ (s0 s1 r0 r1 sf-out k) + (let ((l0 (length s0)) + (l1 (length s1))) + (cond + ((and (= r0 l0) (= r1 l1)) + (k sf-out + (size-of sf-out) + (size-of s0) + (size-of s1) + '())) + + ((= r0 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((= r1 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + ((and (not (null? s0)) + (not (null? s1)) + (= (car s0) (car s1))) + (ext2-shapes (cdr s0) (cdr s1) r0 r1 sf-out + (desc-both (car s0) k))) + + ((> l1 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((> l0 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + (else (error 'ext + "Shapes are incompatible for ext2: ~a, and ~a for min ranks ~a, and ~a~%" + s0 s1 r0 r1)))))) + +(define desc-both + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + (* q1 d) + (cons (vector qout q0 q1) strides))))) + +(define desc-left + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + q1 + (cons (vector qout q0 0) strides))))) + +(define desc-right + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + q0 + (* q1 d) + (cons (vector qout 0 q1) strides))))) + +(define v-copy-flat! + (λ (vg ig a) + ;; copy elements from a to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (vref va (+ a-offset i))))))) + +(define v-add-flat! + (λ (vg ig a) + ;; copy elements to a to vg while adding them to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (+ (vref vg (+ ig i)) + (vref va (+ a-offset i)))))))) + +(define expects-preallocated? + (λ (f) + (let ((a (procedure-arity f))) + (and (integer? a) + (>= a 6))))) + +(define ensure-flat + (λ (z) + (cond + ((number? z) + (flat '() (new-vec 1 (exact->inexact z)) 0)) + (else z)))) + +(define scalarize + (λ (t) + (cond + ((null? (flat-shape t)) (vref (flat-store t) 0)) + (else t)))) + +(define min-shape + (λ (min-rank in-shape) + (drop in-shape (- (length in-shape) min-rank)))) + +(define scalar-shape + (λ (s0 [s1 '()]) '())) + +(include "test/test-D-extend.rkt") + +(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? + functional->preallocated-1-ρ functional->preallocated-1-∇ + functional->preallocated-2-ρ functional->preallocated-2-∇ + merge-shapes min-shape ext2-shapes idxs + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/uniform-tensors/tensors/test/test-A-equality.rkt b/uniform-tensors/tensors/test/test-A-equality.rkt new file mode 100644 index 0000000..9488d8c --- /dev/null +++ b/uniform-tensors/tensors/test/test-A-equality.rkt @@ -0,0 +1,84 @@ +(module+ test + (require rackunit) + + (check-true ((equal-within-tolerance?) 1.00001 1.0001)) + (check-true ((equal-within-tolerance?) 1.0002 1.0001)) + (check-false ((equal-within-tolerance?) 1.0003 1.0001)) + + (define t0 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.0 i))) + 0)) + + + (define t1 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t2 + (flat '(1 2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t3 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (* (quotient i 24) i))) + 0)) + + (define t4 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (- (* 2.000001 (* (quotient i 24) i)) 48.0))) + 0)) + + (check-true (equal-elements? t0 t1)) + + (check-true (equal-elements? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (equal-elements? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (equal-elements? t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (equal-elements? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-true (tensor-equal? t0 t1)) + + (check-false (tensor-equal? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (tensor-equal? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (tensor-equal? t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (tensor-equal? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-tensor-equal? t0 t1) + + (check-tensor-equal? t0 (flat '(2 3 4) + (flat-store t2) + 0)) + + (check-tensor-equal? t1 (flat '(2 3 4) + (flat-store t4) + 24))) diff --git a/uniform-tensors/tensors/test/test-B-tensor-basics.rkt b/uniform-tensors/tensors/test/test-B-tensor-basics.rkt new file mode 100644 index 0000000..903220b --- /dev/null +++ b/uniform-tensors/tensors/test/test-B-tensor-basics.rkt @@ -0,0 +1,33 @@ +(module+ test + (require rackunit) + (require "A-equality.ss") + + (define r0-td 3.0) + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0))) + (define r3-td + (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5)) + (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11)) + (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17)) + (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23)))) + + (check-tensor-equal? (tref r1-td 2) 5.0) + (check-equal? (tlen r1-td) 3) + (check-tensor-equal? (list->tensor (list 3.0 4.0 5.0)) r1-td) + + (check-true (and (tensor? r0-td) (tensor? r1-td))) + (check-false (tensor? '(a b c))) + + (check-tensor-equal? (build-tensor '(4 3 2) + (λ (idx) + (+ (* 6 (ref idx 0)) + (* 2 (ref idx 1)) + (ref idx 2)))) + r3-td) + + (check-tensor-equal? (build-tensor '(1 2 3) (λ (idx) (+ (list-ref idx 0) (list-ref idx 1) (list-ref idx 2)))) + (tensor (tensor (tensor 0 1 2) (tensor 1 2 3)))) + + (check-tensor-equal? (trefs r1-td '(0 2)) (tensor 3.0 5.0)) + + ) diff --git a/uniform-tensors/tensors/test/test-C-tensor-ops.rkt b/uniform-tensors/tensors/test/test-C-tensor-ops.rkt new file mode 100644 index 0000000..4509a31 --- /dev/null +++ b/uniform-tensors/tensors/test/test-C-tensor-ops.rkt @@ -0,0 +1,63 @@ +(module+ test + (require rackunit) + (require "A-equality.ss") + + (define r0-td 3.0) + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0))) + (define r3-td + (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5)) + (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11)) + (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17)) + (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23)))) + + (define test-shape (list 2 2 3)) + + (check-equal? (shape r0-td) (list)) + (check-equal? (shape r1-td) (list 3)) + (check-equal? (shape r2-td) (list 2 3)) + + (check-equal? (rank r0-td) 0) + (check-equal? (rank r1-td) 1) + (check-equal? (rank r2-td) 2) + + (check-equal? (size-of '()) 1) + (check-equal? (size-of test-shape) 12) + + + (check-equal? (size-of '(4 3 2)) 24) + + (check-tensor-equal? (reshape '(24) r3-td) + (tensor 0 1 2 3 4 5 + 6 7 8 9 10 11 + 12 13 14 15 16 17 + 18 19 20 21 22 23)) + + (check-tensor-equal? (reshape '(4 1) (tensor 0 1 2 3)) + (tensor (tensor 0) (tensor 1) (tensor 2) (tensor 3))) + + (check-tensor-equal? (reshape '(6) r2-td) + (tensor 3.0 4.0 5.0 7.0 8.0 9.0)) + + (check-tensor-equal? (reshape '(3 2) r2-td) + (tensor (tensor 3.0 4.0) + (tensor 5.0 7.0) + (tensor 8.0 9.0))) + + + (check-exn exn:fail? + (λ () + (tensor "1 2" 1 2))) + + (check-exn exn:fail? + (λ () + (tensor))) + + (check-exn exn:fail? + (λ () + (tensor 1 (tensor 2 3)))) + + (check-exn exn:fail? + (λ () + (tensor tensor (tensor 2 3)))) +) diff --git a/uniform-tensors/tensors/test/test-D-extend.rkt b/uniform-tensors/tensors/test/test-D-extend.rkt new file mode 100644 index 0000000..3fd924d --- /dev/null +++ b/uniform-tensors/tensors/test/test-D-extend.rkt @@ -0,0 +1,234 @@ +(module+ test + (require rackunit) + + (define sum-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (vset! out-v iₒ + (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) + (+ sum (vref in-v i)))))) + + (define sum-shape-f + (λ (in-f-shape) + '())) + + (define sum (ext1-ρ sum-f 1 sum-shape-f)) + + (check-equal? (min-shape 2 '(3 4 5 6)) '(5 6)) + + (check-equal? (min-shape 0 '(3 4 5 6)) '()) + + (check-equal? (merge-shapes '(3 4 5 6) 1 '()) + '(3 4 5)) + + (define t0 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2 i))) + 0)) + + (check-equal? (flat-store (sum t0)) + #(12.0 44.0 76.0 108.0 140.0 172.0)) + + + (define dup-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (for ([i (in-range 0 sₒ)]) + (vset! out-v (+ iₒ i) + (vref in-v (+ iᵢ (modulo i sᵢ))))))) + + (define dup-shape-f + (λ (in-f-shape) + (list (* 2 (car in-f-shape))))) + + (define dup (ext1-ρ dup-f 1 dup-shape-f)) + (check-equal? (flat-store (dup t0)) + #(0 2 4 6 0 2 4 6 + 8 10 12 14 8 10 12 14 + 16 18 20 22 16 18 20 22 + 24 26 28 30 24 26 28 30 + 32 34 36 38 32 34 36 38 + 40 42 44 46 40 42 44 46)) + + (define s0 '(3 4 5 6)) + (define s1 '(3 7 6)) + (define r0 2) + (define r1 1) + + (ext2-shapes s0 s1 r0 r1 '(5 6) + (λ (s-out size-out q0 q1 strides) + (check-equal? s-out '(3 4 7 5 6)) + (check-equal? size-out 2520) + (check-equal? strides '(#(840 120 42) #(210 30 0) #(30 0 6))) + (let-values (((i0 i1) (idxs strides 0 0 0))) + (check-equal? i0 0) + (check-equal? i1 0)) + + (let-values (((i0 i1) (idxs strides 30 0 0))) + (check-equal? i0 0) + (check-equal? i1 6)) + + (let-values (((i0 i1) (idxs strides 210 0 0))) + (check-equal? i0 30) + (check-equal? i1 0)) + + (let-values (((i0 i1) (idxs strides 240 0 0))) + (check-equal? i0 30) + (check-equal? i1 6)) + + (let-values (((i0 i1) (idxs strides 420 0 0))) + (check-equal? i0 60) + (check-equal? i1 0)) + + (let-values (((i0 i1) (idxs strides 840 0 0))) + (check-equal? i0 120) + (check-equal? i1 42)) + )) + + + (define *-ρ (ext2-ρ * 0 0)) + (define t0sqr (*-ρ t0 t0)) + + (check-equal? (flat-store t0sqr) + #(0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116)) + + (define *-2-1-f + (λ (v0 i0 s0 v1 i1 s1 vout iout sout) + (for ([j0 (in-range 0 s0)]) + (vset! vout (+ iout j0) + (* (vref v0 (+ i0 j0)) + (vref v1 (+ i1 (modulo j0 s1)))))))) + + (define t1 + (flat '(5 6) + (build-vec 30 + (λ (i) (* 2.0 i))) + 0)) + + (define t2 + (flat '(6) + (build-vec 6 + (λ (i) (* 3.0 i))) + 0)) + + (define *-2-1 + (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) + + (define r-1-2 + (*-2-1 t1 t2)) + + (check-equal? (flat-shape r-1-2) '(5 6)) + (check-equal? (flat-store r-1-2) '#(0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0)) + + (define t3 + (flat '(3 5 6) + (build-vec 90 + (λ (i) (* 2.0 i))) + 0)) + + (define t4 + (flat '(3 6) + (build-vec 18 + (λ (i) (* 3.0 i))) + 0)) + + (define r-3-4 + (*-2-1 t3 t4)) + + (check-equal? (flat-shape r-3-4) '(3 5 6)) + (check-equal? (flat-store r-3-4) + #(0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + +(module+ test + (require rackunit) + + (define r0-td 3.0) + (define r1-td (flat '(3) #(3.0 4.0 5.0) 0)) + (define r2-td (flat '(2 3) #(3.0 4.0 5.0 7.0 8.0 9.0) 0)) + + (define +ᶠ +) + (define +ᵈ (λ (a b z) (values z z))) + + (define sqrᶠ (λ (a) (* a a))) + (define sqrᵈ + (λ (a z) (* z 2 a))) + + (define d-sqr (ext1-∇ sqrᵈ 0 scalar-shape)) + + (define one-like + (λ (t) + (let* ((st (flat-shape t)) + (size-t (size-of st))) + (flat st + (new-vec size-t 1.0) + 0)))) + + (check-equal? (flat-store (d-sqr r1-td (one-like r1-td))) '#(6.0 8.0 10.0)) + + (let ((gsqr (d-sqr r2-td (one-like r2-td)))) + (check-equal? (flat-shape gsqr) '(2 3)) + (check-equal? (flat-store gsqr) '#(6.0 8.0 10.0 14.0 16.0 18.0))) + + (define d+ (ext2-∇ +ᵈ 0 0 scalar-shape)) + + (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) + (check-equal? (flat-shape da) '(3)) + (check-equal? (flat-store da) '#(1.0 1.0 1.0)) + (check-equal? (flat-shape db) '(3)) + (check-equal? (flat-store db) '#(1.0 1.0 1.0))) + + (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) + (check-equal? (flat-shape da) '(3)) + (check-equal? (flat-store da) '#(2.0 2.0 2.0)) + (check-equal? (flat-shape db) '(2 3)) + (check-equal? (flat-store db) '#(1.0 1.0 1.0 1.0 1.0 1.0))) + + (define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) + 0 + 0)) + + (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0)))) + (check-equal? gt (tensor 1.0 2.0 3.0)) + (check-equal? gu (tensor 2.0 3.0 4.0))) + + (define sum-1-∇ + (λ (g t it st vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g i (vref vz iz))))) + + (define sum-∇ (ext1-∇ sum-1-∇ 1 (λ (s) '()))) + + (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) + 1.0))) + (check-equal? gt (tensor 1.0 1.0 1.0))) + + (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) + (tensor 2.0 3.0 4.0)) + (tensor 2.0 1.0)))) + (check-equal? gt (tensor (tensor 2.0 2.0 2.0) + (tensor 1.0 1.0 1.0))))) From 28a22636226e5d971e2769e19398004f466f31c3 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 3 Feb 2024 11:54:26 -0500 Subject: [PATCH 18/83] [add-unif]Fix tolerance issues in malted test cases --- malted/test/test-A-core.rkt | 4 ++-- uniform-tensors/tensors/A-equality.rkt | 13 ++----------- 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/malted/test/test-A-core.rkt b/malted/test/test-A-core.rkt index 79d363d..22aad31 100644 --- a/malted/test/test-A-core.rkt +++ b/malted/test/test-A-core.rkt @@ -39,9 +39,9 @@ (tensor -0.04142 -0.03111)))) (let ((a (tensor 7 8 9))) - (check-dual-equal? (exp a) (tensor 1096.6331 2980.9579 8103.0839)) + (check-dual-equal? (exp a) (tensor 1096.6332 2980.9579 8103.0839)) (check-dual-equal? ((∇¹ exp) a) - (list (tensor 1096.6331 2980.9579 8103.0839))) + (list (tensor 1096.6332 2980.9579 8103.0839))) (check-dual-equal? (log a) (tensor 1.9459 2.0794 2.1972)) (check-dual-equal? ((∇¹ log) a) (list (tensor 0.1428 0.125 0.1111))) (check-dual-equal? (sqrt a) (tensor 2.6457 2.8284 3.0)) diff --git a/uniform-tensors/tensors/A-equality.rkt b/uniform-tensors/tensors/A-equality.rkt index 0d2300a..ffa5664 100644 --- a/uniform-tensors/tensors/A-equality.rkt +++ b/uniform-tensors/tensors/A-equality.rkt @@ -15,10 +15,6 @@ (define tolerance (make-parameter 0.0001)) -;;TODO: Discuss if this is the right fix, because without this 2 tests in -;;A-core.rkt fail. (Also delete the printfs later) -(tolerance 0.001) - (define equal-within-tolerance? (make-parameter (λ (actual expected) @@ -48,7 +44,6 @@ (expected-size (flat-size expected)) (actual-store (flat-store actual)) (expected-store (flat-store expected))) - ;(printf "###(equal-elements? ~a ~a)~n" actual expected) (and (equal? actual-size expected-size) (call/cc (λ (return) (for/fold ([check #t]) @@ -58,14 +53,10 @@ [i-expected (in-range expected-offset (+ expected-offset expected-size))]) - (define actual-elem (vref actual-store i-actual)) - (define expected-elem (vref expected-store i-expected)) - ;(printf "###actual: ~a \texpected: ~a~n" actual-elem expected-elem) - ;(printf "### |actual - expected| = ~a~n" (abs (- actual-elem expected-elem))) (cond (((equal-within-tolerance?) - actual-elem - expected-elem) check) + (vref actual-store i-actual) + (vref expected-store i-expected)) check) (else (return #f))))))))) (define-binary-check (check-tensor-equal? tensor-equal? actual expected)) From d935e4329bc6343e60be78534072a1902b1381b1 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 3 Feb 2024 12:13:28 -0500 Subject: [PATCH 19/83] [add-unif]Fix build-vec in uniform-tensors --- uniform-tensors/tensors/0-vectors.rkt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uniform-tensors/tensors/0-vectors.rkt b/uniform-tensors/tensors/0-vectors.rkt index 2df0e8e..7e3283e 100644 --- a/uniform-tensors/tensors/0-vectors.rkt +++ b/uniform-tensors/tensors/0-vectors.rkt @@ -14,7 +14,7 @@ (define list->vec list->f32vector) (define build-vec (λ (n proc) - (list->vec (map proc (range n))))) + (list->vec (map (compose exact->inexact proc) (range n))))) (define vec->cpointer f32vector->cpointer) (define-for-syntax debug-leaks? #f) From 29195a3c54268f052b4f9941b2cd692f2ceee1ed Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 17:53:29 -0400 Subject: [PATCH 20/83] [add-unif]Generate indices like a GPU kernel --- uniform-tensors/tensors/D-extend.rkt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/uniform-tensors/tensors/D-extend.rkt b/uniform-tensors/tensors/D-extend.rkt index 07d955d..bd45305 100644 --- a/uniform-tensors/tensors/D-extend.rkt +++ b/uniform-tensors/tensors/D-extend.rkt @@ -190,7 +190,8 @@ (size-out (size-of s-out)) (v-out (new-vec size-out 0.0))) (for ([i-out (in-range 0 size-out stride-out)] - [i0 (in-range off0 (+ off0 size0) stride0)]) + #;[i0 (in-range off0 (+ off0 size0) stride0)]) + (define i0 (+ off0 (* (/ i-out stride-out) stride0))) (f v0 i0 stride0 v-out i-out stride-out)) (flat s-out v-out 0)))) @@ -212,7 +213,8 @@ (g0 (new-vec size0 0.0))) (for ([iz (in-range 0 size-z stride-z)] - [i0 (in-range off0 (+ off0 size0) stride0)]) + #;[i0 (in-range off0 (+ off0 size0) stride0)]) + (define i0 (+ off0 (* (/ iz stride-z) stride0))) (fᵈ g0 v0 i0 stride0 vz iz stride-z)) (flat s0 g0 0)))) From a0c8506c6d8f86fbf9dfd37f731073f8fa66539c Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 28 Feb 2024 18:43:16 -0500 Subject: [PATCH 21/83] [add-unif]Fix test cases for uniform-tensors --- Makefile | 1 + .../autodiff/test/test-E-print.rkt | 50 ++++---- uniform-tensors/ext-ops/K-concat.rkt | 4 +- uniform-tensors/tensors/A-equality.rkt | 2 +- .../tensors/test/test-D-extend.rkt | 119 +++++++++--------- 5 files changed, 90 insertions(+), 86 deletions(-) diff --git a/Makefile b/Makefile index fb1aade..17cf747 100644 --- a/Makefile +++ b/Makefile @@ -232,6 +232,7 @@ MALTED_SOURCES=\ # All the sources together, plus entry points SOURCES=$(LEARNER_SOURCES)\ $(FLAT_SOURCES)\ + $(UNIFORM_SOURCES)\ $(NESTED_SOURCES)\ $(TOOLS_SOURCES)\ $(MALTED_SOURCES)\ diff --git a/uniform-tensors/autodiff/test/test-E-print.rkt b/uniform-tensors/autodiff/test/test-E-print.rkt index 08c7c8a..30f51e6 100644 --- a/uniform-tensors/autodiff/test/test-E-print.rkt +++ b/uniform-tensors/autodiff/test/test-E-print.rkt @@ -18,54 +18,54 @@ deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) - (check-equal? (make-printable-flat long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable-flat long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) (check-equal? (make-printable-flat deep-tensor 3) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...))) (check-equal? (make-printable-flat deeper-tensor 3) (fake-tensor (list (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) '...))) (parameterize ((max-tensor-print-length 3)) - (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) (list - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) (fake-tensor (list (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) '...)))))) diff --git a/uniform-tensors/ext-ops/K-concat.rkt b/uniform-tensors/ext-ops/K-concat.rkt index e8e3a99..f3adea6 100644 --- a/uniform-tensors/ext-ops/K-concat.rkt +++ b/uniform-tensors/ext-ops/K-concat.rkt @@ -17,9 +17,9 @@ (for ([i (in-range 0 stride-out)]) (cond ((< i stride0) - (vector-set! v-out (+ i-out i) (vector-ref v0 (+ i0 i)))) + (vset! v-out (+ i-out i) (vref v0 (+ i0 i)))) (else - (vector-set! v-out (+ i-out i) (vector-ref v1 (+ i1 (- i stride0))))))))) + (vset! v-out (+ i-out i) (vref v1 (+ i1 (- i stride0))))))))) (define concat-base-∇ (λ (g0 g1 v0 i0 stride0 diff --git a/uniform-tensors/tensors/A-equality.rkt b/uniform-tensors/tensors/A-equality.rkt index ffa5664..639533f 100644 --- a/uniform-tensors/tensors/A-equality.rkt +++ b/uniform-tensors/tensors/A-equality.rkt @@ -63,4 +63,4 @@ (include "test/test-A-equality.rkt") -(provide tolerance equal-within-tolerance? tensor-equal? check-tensor-equal?) +(provide tolerance equal-within-tolerance? tensor-equal? check-tensor-equal? equal-elements?) diff --git a/uniform-tensors/tensors/test/test-D-extend.rkt b/uniform-tensors/tensors/test/test-D-extend.rkt index 3fd924d..f7c8eca 100644 --- a/uniform-tensors/tensors/test/test-D-extend.rkt +++ b/uniform-tensors/tensors/test/test-D-extend.rkt @@ -1,5 +1,7 @@ (module+ test (require rackunit) + (require "A-equality.rkt") + (require "B-tensor-basics.rkt") (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) @@ -27,8 +29,8 @@ (* 2 i))) 0)) - (check-equal? (flat-store (sum t0)) - #(12.0 44.0 76.0 108.0 140.0 172.0)) + (check-true (equal-elements? (sum t0) + (tensor 12.0 44.0 76.0 108.0 140.0 172.0))) (define dup-f @@ -42,13 +44,13 @@ (list (* 2 (car in-f-shape))))) (define dup (ext1-ρ dup-f 1 dup-shape-f)) - (check-equal? (flat-store (dup t0)) - #(0 2 4 6 0 2 4 6 - 8 10 12 14 8 10 12 14 - 16 18 20 22 16 18 20 22 - 24 26 28 30 24 26 28 30 - 32 34 36 38 32 34 36 38 - 40 42 44 46 40 42 44 46)) + (check-true (equal-elements? (dup t0) + (tensor 0 2 4 6 0 2 4 6 + 8 10 12 14 8 10 12 14 + 16 18 20 22 16 18 20 22 + 24 26 28 30 24 26 28 30 + 32 34 36 38 32 34 36 38 + 40 42 44 46 40 42 44 46))) (define s0 '(3 4 5 6)) (define s1 '(3 7 6)) @@ -89,13 +91,14 @@ (define *-ρ (ext2-ρ * 0 0)) (define t0sqr (*-ρ t0 t0)) - (check-equal? (flat-store t0sqr) - #(0 4 16 36 - 64 100 144 196 - 256 324 400 484 - 576 676 784 900 - 1024 1156 1296 1444 - 1600 1764 1936 2116)) + (check-true (equal-elements? + t0sqr + (tensor 0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116))) (define *-2-1-f (λ (v0 i0 s0 v1 i1 s1 vout iout sout) @@ -122,12 +125,14 @@ (define r-1-2 (*-2-1 t1 t2)) - (check-equal? (flat-shape r-1-2) '(5 6)) - (check-equal? (flat-store r-1-2) '#(0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0)) + (check-tensor-equal? r-1-2 + (reshape + '(5 6) + (tensor 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0))) (define t3 (flat '(3 5 6) @@ -145,31 +150,34 @@ (*-2-1 t3 t4)) (check-equal? (flat-shape r-3-4) '(3 5 6)) - (check-equal? (flat-store r-3-4) - #(0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0 - - 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 - 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 - 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 - 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 - 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 - - 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 - 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 - 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 - 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 - 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + (check-tensor-equal? r-3-4 + (reshape + '(3 5 6) + (tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0)))) (module+ test (require rackunit) (define r0-td 3.0) - (define r1-td (flat '(3) #(3.0 4.0 5.0) 0)) - (define r2-td (flat '(2 3) #(3.0 4.0 5.0 7.0 8.0 9.0) 0)) + (define r1-td (flat '(3) (list->vec '(3.0 4.0 5.0)) 0)) + (define r2-td (flat '(2 3) (list->vec '(3.0 4.0 5.0 7.0 8.0 9.0)) 0)) (define +ᶠ +) (define +ᵈ (λ (a b z) (values z z))) @@ -188,33 +196,28 @@ (new-vec size-t 1.0) 0)))) - (check-equal? (flat-store (d-sqr r1-td (one-like r1-td))) '#(6.0 8.0 10.0)) + (check-true (equal-elements? (d-sqr r1-td (one-like r1-td)) (tensor 6.0 8.0 10.0))) (let ((gsqr (d-sqr r2-td (one-like r2-td)))) - (check-equal? (flat-shape gsqr) '(2 3)) - (check-equal? (flat-store gsqr) '#(6.0 8.0 10.0 14.0 16.0 18.0))) + (check-tensor-equal? gsqr (reshape '(2 3) (tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) (define d+ (ext2-∇ +ᵈ 0 0 scalar-shape)) (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) - (check-equal? (flat-shape da) '(3)) - (check-equal? (flat-store da) '#(1.0 1.0 1.0)) - (check-equal? (flat-shape db) '(3)) - (check-equal? (flat-store db) '#(1.0 1.0 1.0))) + (check-tensor-equal? da (tensor 1.0 1.0 1.0)) + (check-tensor-equal? db (tensor 1.0 1.0 1.0))) (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) - (check-equal? (flat-shape da) '(3)) - (check-equal? (flat-store da) '#(2.0 2.0 2.0)) - (check-equal? (flat-shape db) '(2 3)) - (check-equal? (flat-store db) '#(1.0 1.0 1.0 1.0 1.0 1.0))) + (check-tensor-equal? da (tensor 2.0 2.0 2.0)) + (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) (define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) 0 0)) (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0)))) - (check-equal? gt (tensor 1.0 2.0 3.0)) - (check-equal? gu (tensor 2.0 3.0 4.0))) + (check-tensor-equal? gt (tensor 1.0 2.0 3.0)) + (check-tensor-equal? gu (tensor 2.0 3.0 4.0))) (define sum-1-∇ (λ (g t it st vz iz sz) @@ -225,10 +228,10 @@ (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) 1.0))) - (check-equal? gt (tensor 1.0 1.0 1.0))) + (check-tensor-equal? gt (tensor 1.0 1.0 1.0))) (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0)))) - (check-equal? gt (tensor (tensor 2.0 2.0 2.0) - (tensor 1.0 1.0 1.0))))) + (check-tensor-equal? gt (tensor (tensor 2.0 2.0 2.0) + (tensor 1.0 1.0 1.0))))) From 76eb5f212bafa115ccaae54c99fcc374952763bf Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Thu, 7 Mar 2024 11:14:53 -0500 Subject: [PATCH 22/83] [add-unif]Fix concat and add vector pointer offset function --- info.rkt | 8 ++++++-- uniform-tensors/ext-ops/K-concat.rkt | 2 +- uniform-tensors/tensors/0-vectors.rkt | 10 +++++++++- uniform-tensors/tensors/test/test-D-extend.rkt | 1 - 4 files changed, 16 insertions(+), 5 deletions(-) diff --git a/info.rkt b/info.rkt index b00546a..6b803c0 100644 --- a/info.rkt +++ b/info.rkt @@ -1,12 +1,16 @@ #lang info (define collection "malt") -(define deps '(["base" #:version "8.2"] "rackunit-lib")) +(define deps '(["base" #:version "8.2"] + "rackunit-lib" + "opencl" + "xxhash" + "string-interpolation")) (define pkg-desc "A MAchine Learning Toolkit accompanying The Little Learner: A Straight Line to Deep Learning by Daniel P. Friedman and Anurag Mendhekar") (define version "0.1") (define compile-omit-paths (list #rx"test\\\\" #rx"test/")) (define test-omit-paths (list #rx"test\\\\" #rx"test/")) -(define pkg-authors '("Anurag Mendhekar" "Daniel P. Friedman")) +(define pkg-authors '("Anurag Mendhekar" "Daniel P. Friedman" "Darshal Shetty")) (define build-deps '("scribble-lib" "racket-doc" "rackunit-lib")) (define scribblings '(("scribblings/malt.scrbl" (multi-page)))) (define license 'MIT) diff --git a/uniform-tensors/ext-ops/K-concat.rkt b/uniform-tensors/ext-ops/K-concat.rkt index f3adea6..5cc1ee5 100644 --- a/uniform-tensors/ext-ops/K-concat.rkt +++ b/uniform-tensors/ext-ops/K-concat.rkt @@ -53,7 +53,7 @@ (let ((st (shape-ρ t)) (su (shape-ρ u))) (ensure-compatible-shapes n st su) - ((ext2 concat-base-ρ n n concat-shape) t u))))) + ((ext2-ρ concat-base-ρ n n concat-shape) t u))))) (define ensure-compatible-shapes (λ (n st su) diff --git a/uniform-tensors/tensors/0-vectors.rkt b/uniform-tensors/tensors/0-vectors.rkt index 7e3283e..b8c104f 100644 --- a/uniform-tensors/tensors/0-vectors.rkt +++ b/uniform-tensors/tensors/0-vectors.rkt @@ -1,5 +1,6 @@ #lang racket (require ffi/vector) +(require ffi/unsafe) ;;------------------------------------------------ ;; Raw representation of vectors @@ -16,6 +17,13 @@ (λ (n proc) (list->vec (map (compose exact->inexact proc) (range n))))) (define vec->cpointer f32vector->cpointer) +(define vref-cpointer + (λ (v i) + (unless (< i (vlen v)) + (error 'vref-cpointer + "Index ~a out of range [0, ~a]" + i (vlen v))) + (ptr-add (vec->cpointer v) i _float))) (define-for-syntax debug-leaks? #f) (define-syntax when-debug-leaks @@ -39,7 +47,7 @@ [is (in-range isrc (+ n isrc))]) (vset! dest id (vref src is))))) -(provide vec? vec vref vset! vlen vcopy list->vec build-vec vec->cpointer new-vec) +(provide vec? vec vref vset! vlen vcopy list->vec build-vec vec->cpointer vref-cpointer new-vec) ;;------------------------------------------------ ;; Memory management for flat-vectors diff --git a/uniform-tensors/tensors/test/test-D-extend.rkt b/uniform-tensors/tensors/test/test-D-extend.rkt index f7c8eca..5b2c8d0 100644 --- a/uniform-tensors/tensors/test/test-D-extend.rkt +++ b/uniform-tensors/tensors/test/test-D-extend.rkt @@ -149,7 +149,6 @@ (define r-3-4 (*-2-1 t3 t4)) - (check-equal? (flat-shape r-3-4) '(3 5 6)) (check-tensor-equal? r-3-4 (reshape '(3 5 6) From 5bdda045e80e6ddc34c8e2c04b9b0dd479bb112e Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:28:42 -0400 Subject: [PATCH 23/83] [add-unif]Minor fixes and a bugged accelerated runtime --- uniform-tensors/tensors/0-vectors.rkt | 10 +++++++++- uniform-tensors/tensors/C-tensor-ops.rkt | 2 +- uniform-tensors/tensors/test/test-D-extend.rkt | 8 ++++---- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/uniform-tensors/tensors/0-vectors.rkt b/uniform-tensors/tensors/0-vectors.rkt index b8c104f..7b73c29 100644 --- a/uniform-tensors/tensors/0-vectors.rkt +++ b/uniform-tensors/tensors/0-vectors.rkt @@ -47,7 +47,15 @@ [is (in-range isrc (+ n isrc))]) (vset! dest id (vref src is))))) -(provide vec? vec vref vset! vlen vcopy list->vec build-vec vec->cpointer vref-cpointer new-vec) +(define print-vec + (λ (v (off 0) (port (current-output-port))) + (fprintf port "#(") + (for ((i (in-range off (vlen v)))) + (fprintf port "~a " (vref v i))) + (fprintf port ")"))) + +(provide vec? vec vref vset! vlen vcopy print-vec + list->vec build-vec vec->cpointer vref-cpointer new-vec) ;;------------------------------------------------ ;; Memory management for flat-vectors diff --git a/uniform-tensors/tensors/C-tensor-ops.rkt b/uniform-tensors/tensors/C-tensor-ops.rkt index f9056b4..cf32d23 100644 --- a/uniform-tensors/tensors/C-tensor-ops.rkt +++ b/uniform-tensors/tensors/C-tensor-ops.rkt @@ -26,7 +26,7 @@ (cond ((= (size-of s) (flat-size t)) (flat s (flat-store t) (flat-offset t))) - (else (error "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + (else (error 'tensor-reshape "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) (include "test/test-C-tensor-ops.rkt") diff --git a/uniform-tensors/tensors/test/test-D-extend.rkt b/uniform-tensors/tensors/test/test-D-extend.rkt index 5b2c8d0..0eb7ec4 100644 --- a/uniform-tensors/tensors/test/test-D-extend.rkt +++ b/uniform-tensors/tensors/test/test-D-extend.rkt @@ -129,10 +129,10 @@ (reshape '(5 6) (tensor 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0))) + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0))) (define t3 (flat '(3 5 6) From 7c763714cc7ed28307f7efee5763e48b1f33cfd3 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:56:19 -0400 Subject: [PATCH 24/83] =?UTF-8?q?[add-unif]use=20one=20kernel=20in=20ext2-?= =?UTF-8?q?=E2=88=87=20and=20accelerate=20all=20ext*=20function=20calls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- uniform-tensors.rkt | 2 ++ uniform-tensors/ext-ops/E-argmax.rkt | 2 +- uniform-tensors/ext-ops/F-max.rkt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/uniform-tensors.rkt b/uniform-tensors.rkt index e8479a6..6e4b17b 100644 --- a/uniform-tensors.rkt +++ b/uniform-tensors.rkt @@ -8,6 +8,8 @@ (require "uniform-tensors/ext-ops.rkt") (provide + tolerance + len ref refr tref tlen list->tensor tensor build-tensor diff --git a/uniform-tensors/ext-ops/E-argmax.rkt b/uniform-tensors/ext-ops/E-argmax.rkt index d68966d..1a2a285 100644 --- a/uniform-tensors/ext-ops/E-argmax.rkt +++ b/uniform-tensors/ext-ops/E-argmax.rkt @@ -8,7 +8,7 @@ (λ (v0 i0 stride0 v-out i-out stride-out) (vset! v-out i-out - (for/fold ([max 0.0] + (for/fold ([max -inf.0] [max-i -1] #:result max-i) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vref v0 i))) diff --git a/uniform-tensors/ext-ops/F-max.rkt b/uniform-tensors/ext-ops/F-max.rkt index 2d5f5d6..101ffe2 100644 --- a/uniform-tensors/ext-ops/F-max.rkt +++ b/uniform-tensors/ext-ops/F-max.rkt @@ -8,7 +8,7 @@ (λ (v0 i0 stride0 v-out i-out stride-out) (vset! v-out i-out - (for/fold ([max 0.0]) + (for/fold ([max -inf.0]) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vref v0 i))) (cond From 9ec44a843c8797717314ebbc3d7d2c381c6f1484 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:30:17 -0400 Subject: [PATCH 25/83] [add-unif]Add zeroes as a primitive --- uniform-tensors.rkt | 2 +- uniform-tensors/ext-ops.rkt | 2 +- uniform-tensors/ext-ops/A-scalar-ops.rkt | 5 ++++- uniform-tensors/no-duals-no-overrides.rkt | 2 +- uniform-tensors/no-duals.rkt | 2 +- uniform-tensors/no-overrides.rkt | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/uniform-tensors.rkt b/uniform-tensors.rkt index 6e4b17b..461c9f4 100644 --- a/uniform-tensors.rkt +++ b/uniform-tensors.rkt @@ -33,7 +33,7 @@ (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/uniform-tensors/ext-ops.rkt b/uniform-tensors/ext-ops.rkt index 83af7de..fc223f5 100644 --- a/uniform-tensors/ext-ops.rkt +++ b/uniform-tensors/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/uniform-tensors/ext-ops/A-scalar-ops.rkt b/uniform-tensors/ext-ops/A-scalar-ops.rkt index b251e80..9096209 100644 --- a/uniform-tensors/ext-ops/A-scalar-ops.rkt +++ b/uniform-tensors/ext-ops/A-scalar-ops.rkt @@ -121,6 +121,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 @@ -132,4 +135,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/uniform-tensors/no-duals-no-overrides.rkt b/uniform-tensors/no-duals-no-overrides.rkt index ac07a7a..07ca22e 100644 --- a/uniform-tensors/no-duals-no-overrides.rkt +++ b/uniform-tensors/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ concat-ρ concat-n-ρ flatten-ρ diff --git a/uniform-tensors/no-duals.rkt b/uniform-tensors/no-duals.rkt index 927c8c7..cd1bcaf 100644 --- a/uniform-tensors/no-duals.rkt +++ b/uniform-tensors/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/uniform-tensors/no-overrides.rkt b/uniform-tensors/no-overrides.rkt index 35dcbdd..05844b7 100644 --- a/uniform-tensors/no-overrides.rkt +++ b/uniform-tensors/no-overrides.rkt @@ -34,7 +34,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ From 9b2ee1f635a047856e8c96d8b9c42572f38f2922 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:58:57 -0400 Subject: [PATCH 26/83] [add-unif]Fix apply-*-2 --- uniform-tensors/autodiff/B-prims.rkt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uniform-tensors/autodiff/B-prims.rkt b/uniform-tensors/autodiff/B-prims.rkt index e029cc9..277e9ea 100644 --- a/uniform-tensors/autodiff/B-prims.rkt +++ b/uniform-tensors/autodiff/B-prims.rkt @@ -152,7 +152,7 @@ (λ (ρ-fn ra rb shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) @@ -176,7 +176,7 @@ (λ (∇-fn ra rb z shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) From a277adce92e21f8ab87bcfedba275dfab2e9c042 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:31:02 -0400 Subject: [PATCH 27/83] Add self-hosted runner to CI Signed-off-by: Darshal Shetty --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 787919a..0376947 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -3,7 +3,8 @@ name: CI jobs: build: name: "Build on Racket '${{ matrix.racket-version }}' (${{ matrix.racket-variant }})" - runs-on: ubuntu-latest + #runs-on: ubuntu-latest + runs-on: self-hosted strategy: matrix: racket-version: ["stable", "current"] From 7e050fe446a87e3713853712a7de7c333a993ba0 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:34:58 -0400 Subject: [PATCH 28/83] Update ci.yml Signed-off-by: Darshal Shetty --- .github/workflows/ci.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0376947..6108996 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,12 +11,12 @@ jobs: racket-variant: ["BC", "CS"] steps: - uses: actions/checkout@v3 - - uses: Bogdanp/setup-racket@v1.9.1 - with: - architecture: x64 - distribution: full - variant: ${{ matrix.racket-variant }} - version: ${{ matrix.racket-version }} + #- uses: Bogdanp/setup-racket@v1.9.1 + # with: + # architecture: x64 + # distribution: full + # variant: ${{ matrix.racket-variant }} + # version: ${{ matrix.racket-version }} - name: Installing malt and its dependencies run: raco pkg install --no-docs --auto --name malt - name: Compiling malt and building its docs From bd48b30c1f396f299f05add8c3a5e151a28840e6 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:39:11 -0400 Subject: [PATCH 29/83] Update ci.yml Signed-off-by: Darshal Shetty --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6108996..b55434f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: # variant: ${{ matrix.racket-variant }} # version: ${{ matrix.racket-version }} - name: Installing malt and its dependencies - run: raco pkg install --no-docs --auto --name malt + run: raco pkg install --no-docs --auto --name malt --skip-installed - name: Compiling malt and building its docs run: raco setup --check-pkg-deps --unused-pkg-deps malt - name: Testing malt From 989d1beab5f385a7f8d971e14936a2fd97393fcf Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:50:52 -0400 Subject: [PATCH 30/83] Update ci.yml Signed-off-by: Darshal Shetty --- .github/workflows/ci.yml | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b55434f..b4266c9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,15 +11,5 @@ jobs: racket-variant: ["BC", "CS"] steps: - uses: actions/checkout@v3 - #- uses: Bogdanp/setup-racket@v1.9.1 - # with: - # architecture: x64 - # distribution: full - # variant: ${{ matrix.racket-variant }} - # version: ${{ matrix.racket-version }} - - name: Installing malt and its dependencies - run: raco pkg install --no-docs --auto --name malt --skip-installed - - name: Compiling malt and building its docs - run: raco setup --check-pkg-deps --unused-pkg-deps malt - name: Testing malt - run: raco test -x -p malt + run: ./make From 0313dcd49232a7173d37d349d7f84a3b1b8226d1 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:51:52 -0400 Subject: [PATCH 31/83] Update ci.yml Signed-off-by: Darshal Shetty --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b4266c9..2bb5e3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,4 +12,4 @@ jobs: steps: - uses: actions/checkout@v3 - name: Testing malt - run: ./make + run: make From 7ff8886b71851b1b0d58054ec85e10a22cebd6b8 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:59:27 -0400 Subject: [PATCH 32/83] Update ci.yml Signed-off-by: Darshal Shetty --- .github/workflows/ci.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2bb5e3e..2660aac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,13 +2,8 @@ on: [push, pull_request] name: CI jobs: build: - name: "Build on Racket '${{ matrix.racket-version }}' (${{ matrix.racket-variant }})" - #runs-on: ubuntu-latest + name: "Build on Racket CS" runs-on: self-hosted - strategy: - matrix: - racket-version: ["stable", "current"] - racket-variant: ["BC", "CS"] steps: - uses: actions/checkout@v3 - name: Testing malt From bce70676b537692eee52add482726fe710bac490 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 7 Sep 2024 11:31:02 -0400 Subject: [PATCH 33/83] Add self-hosted runner to CI --- .github/workflows/ci.yml | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 787919a..2660aac 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,23 +2,9 @@ on: [push, pull_request] name: CI jobs: build: - name: "Build on Racket '${{ matrix.racket-version }}' (${{ matrix.racket-variant }})" - runs-on: ubuntu-latest - strategy: - matrix: - racket-version: ["stable", "current"] - racket-variant: ["BC", "CS"] + name: "Build on Racket CS" + runs-on: self-hosted steps: - uses: actions/checkout@v3 - - uses: Bogdanp/setup-racket@v1.9.1 - with: - architecture: x64 - distribution: full - variant: ${{ matrix.racket-variant }} - version: ${{ matrix.racket-version }} - - name: Installing malt and its dependencies - run: raco pkg install --no-docs --auto --name malt - - name: Compiling malt and building its docs - run: raco setup --check-pkg-deps --unused-pkg-deps malt - name: Testing malt - run: raco test -x -p malt + run: make From fa4f20d144dce954149522458dfd1baaa653461c Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Thu, 7 Mar 2024 16:25:48 -0500 Subject: [PATCH 34/83] [add-acc]Clone flat tensors into accelerated tensors --- Makefile | 50 +++ accelerated-tensors.rkt | 45 ++ accelerated-tensors/autodiff.rkt | 23 + accelerated-tensors/autodiff/A-autodiff.rkt | 120 ++++++ accelerated-tensors/autodiff/B-prims.rkt | 221 ++++++++++ .../autodiff/C-dualized-tensor-ops.rkt | 51 +++ .../autodiff/D-test-helpers.rkt | 47 +++ accelerated-tensors/autodiff/E-print.rkt | 87 ++++ .../autodiff/test/test-A-autodiff.rkt | 15 + .../autodiff/test/test-E-print.rkt | 71 ++++ accelerated-tensors/ext-impl.rkt | 28 ++ accelerated-tensors/ext-ops.rkt | 40 ++ accelerated-tensors/ext-ops/A-scalar-ops.rkt | 135 ++++++ accelerated-tensors/ext-ops/B-comparators.rkt | 85 ++++ accelerated-tensors/ext-ops/C-star-2-1.rkt | 44 ++ accelerated-tensors/ext-ops/D-sum.rkt | 68 +++ accelerated-tensors/ext-ops/E-argmax.rkt | 41 ++ accelerated-tensors/ext-ops/F-max.rkt | 49 +++ accelerated-tensors/ext-ops/G-correlate.rkt | 95 +++++ accelerated-tensors/ext-ops/I-flatten.rkt | 31 ++ accelerated-tensors/ext-ops/K-concat.rkt | 76 ++++ .../ext-ops/test/test-A-scalar-ops.rkt | 122 ++++++ .../ext-ops/test/test-B-comparators.rkt | 12 + .../ext-ops/test/test-C-star-2-1.rkt | 24 ++ .../ext-ops/test/test-D-sum.rkt | 58 +++ .../ext-ops/test/test-E-argmax.rkt | 17 + .../ext-ops/test/test-F-max.rkt | 10 + .../ext-ops/test/test-G-correlate.rkt | 118 ++++++ .../ext-ops/test/test-I-flatten.rkt | 13 + .../ext-ops/test/test-K-concat.rkt | 126 ++++++ accelerated-tensors/no-duals-no-overrides.rkt | 29 ++ accelerated-tensors/no-duals.rkt | 29 ++ accelerated-tensors/no-overrides.rkt | 43 ++ accelerated-tensors/tensors.rkt | 22 + accelerated-tensors/tensors/0-vectors.rkt | 93 +++++ accelerated-tensors/tensors/1-flats.rkt | 73 ++++ accelerated-tensors/tensors/A-equality.rkt | 66 +++ .../tensors/B-tensor-basics.rkt | 184 ++++++++ accelerated-tensors/tensors/C-tensor-ops.rkt | 35 ++ accelerated-tensors/tensors/D-extend.rkt | 393 ++++++++++++++++++ .../tensors/test/test-A-equality.rkt | 84 ++++ .../tensors/test/test-B-tensor-basics.rkt | 33 ++ .../tensors/test/test-C-tensor-ops.rkt | 63 +++ .../tensors/test/test-D-extend.rkt | 236 +++++++++++ impl-no-duals-no-overrides.rkt | 4 + impl-no-duals.rkt | 2 + impl-no-overrides.rkt | 2 + impl.rkt | 2 + set-impl.rkt | 3 +- 49 files changed, 3317 insertions(+), 1 deletion(-) create mode 100644 accelerated-tensors.rkt create mode 100644 accelerated-tensors/autodiff.rkt create mode 100644 accelerated-tensors/autodiff/A-autodiff.rkt create mode 100644 accelerated-tensors/autodiff/B-prims.rkt create mode 100644 accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt create mode 100644 accelerated-tensors/autodiff/D-test-helpers.rkt create mode 100644 accelerated-tensors/autodiff/E-print.rkt create mode 100644 accelerated-tensors/autodiff/test/test-A-autodiff.rkt create mode 100644 accelerated-tensors/autodiff/test/test-E-print.rkt create mode 100644 accelerated-tensors/ext-impl.rkt create mode 100644 accelerated-tensors/ext-ops.rkt create mode 100644 accelerated-tensors/ext-ops/A-scalar-ops.rkt create mode 100644 accelerated-tensors/ext-ops/B-comparators.rkt create mode 100644 accelerated-tensors/ext-ops/C-star-2-1.rkt create mode 100644 accelerated-tensors/ext-ops/D-sum.rkt create mode 100644 accelerated-tensors/ext-ops/E-argmax.rkt create mode 100644 accelerated-tensors/ext-ops/F-max.rkt create mode 100644 accelerated-tensors/ext-ops/G-correlate.rkt create mode 100644 accelerated-tensors/ext-ops/I-flatten.rkt create mode 100644 accelerated-tensors/ext-ops/K-concat.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-B-comparators.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-D-sum.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-E-argmax.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-F-max.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-G-correlate.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-I-flatten.rkt create mode 100644 accelerated-tensors/ext-ops/test/test-K-concat.rkt create mode 100644 accelerated-tensors/no-duals-no-overrides.rkt create mode 100644 accelerated-tensors/no-duals.rkt create mode 100644 accelerated-tensors/no-overrides.rkt create mode 100644 accelerated-tensors/tensors.rkt create mode 100644 accelerated-tensors/tensors/0-vectors.rkt create mode 100644 accelerated-tensors/tensors/1-flats.rkt create mode 100644 accelerated-tensors/tensors/A-equality.rkt create mode 100644 accelerated-tensors/tensors/B-tensor-basics.rkt create mode 100644 accelerated-tensors/tensors/C-tensor-ops.rkt create mode 100644 accelerated-tensors/tensors/D-extend.rkt create mode 100644 accelerated-tensors/tensors/test/test-A-equality.rkt create mode 100644 accelerated-tensors/tensors/test/test-B-tensor-basics.rkt create mode 100644 accelerated-tensors/tensors/test/test-C-tensor-ops.rkt create mode 100644 accelerated-tensors/tensors/test/test-D-extend.rkt diff --git a/Makefile b/Makefile index 17cf747..ea57e90 100644 --- a/Makefile +++ b/Makefile @@ -13,6 +13,7 @@ TEST_FLAGS=-q LEARNER_DIR=learner FLAT_DIR=flat-tensors UNIFORM_DIR=uniform-tensors +ACCELERATED_DIR=accelerated-tensors NESTED_DIR=nested-tensors TOOLS_DIR=tools MALTED_DIR=malted @@ -158,6 +159,54 @@ UNIFORM_SOURCES=$(UNIFORM_TENSORS_SOURCES)\ $(UNIFORM_EXT_OPS_SOURCES)\ $(UNIFORM_LOADERS) +# accelerated-tensors +ACCELERATED_TENSORS_DIR=$(ACCELERATED_DIR)/tensors +ACCELERATED_AUTODIFF_DIR=$(ACCELERATED_DIR)/autodiff +ACCELERATED_EXT_OPS_DIR=$(ACCELERATED_DIR)/ext-ops + +ACCELERATED_TENSORS_SOURCES=\ + $(ACCELERATED_TENSORS_DIR)/0-vectors.rkt\ + $(ACCELERATED_TENSORS_DIR)/1-flats.rkt\ + $(ACCELERATED_TENSORS_DIR)/A-equality.rkt\ + $(ACCELERATED_TENSORS_DIR)/B-tensor-basics.rkt\ + $(ACCELERATED_TENSORS_DIR)/C-tensor-ops.rkt\ + $(ACCELERATED_TENSORS_DIR)/D-extend.rkt\ + $(ACCELERATED_DIR)/tensors.rkt + +ACCELERATED_AUTODIFF_SOURCES=\ + $(ACCELERATED_AUTODIFF_DIR)/A-autodiff.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/B-prims.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/C-dualized-tensor-ops.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/D-test-helpers.rkt\ + $(ACCELERATED_AUTODIFF_DIR)/E-print.rkt\ + $(ACCELERATED_DIR)/autodiff.rkt + +ACCELERATED_EXT_OPS_SOURCES=\ + $(ACCELERATED_EXT_OPS_DIR)/A-scalar-ops.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/B-comparators.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/C-star-2-1.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/D-sum.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/E-argmax.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/F-max.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/G-correlate.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/I-flatten.rkt\ + $(ACCELERATED_EXT_OPS_DIR)/K-concat.rkt\ + $(ACCELERATED_DIR)/ext-ops.rkt + +ACCELERATED_LOADERS=\ + $(ACCELERATED_DIR)/no-duals-no-overrides.rkt\ + $(ACCELERATED_DIR)/no-duals.rkt\ + $(ACCELERATED_DIR)/no-overrides.rkt\ + $(ACCELERATED_DIR)/tensors.rkt\ + $(ACCELERATED_DIR)/autodiff.rkt\ + $(ACCELERATED_DIR)/ext-impl.rkt\ + $(ACCELERATED_DIR)/ext-ops.rkt + +ACCELERATED_SOURCES=$(ACCELERATED_TENSORS_SOURCES)\ + $(ACCELERATED_AUTODIFF_SOURCES)\ + $(ACCELERATED_EXT_OPS_SOURCES)\ + $(ACCELERATED_LOADERS) + # nested-tensors NESTED_TENSORS_DIR=$(NESTED_DIR)/tensors NESTED_AUTODIFF_DIR=$(NESTED_DIR)/autodiff @@ -233,6 +282,7 @@ MALTED_SOURCES=\ SOURCES=$(LEARNER_SOURCES)\ $(FLAT_SOURCES)\ $(UNIFORM_SOURCES)\ + $(ACCELERATED_SOURCES)\ $(NESTED_SOURCES)\ $(TOOLS_SOURCES)\ $(MALTED_SOURCES)\ diff --git a/accelerated-tensors.rkt b/accelerated-tensors.rkt new file mode 100644 index 0000000..158fb6a --- /dev/null +++ b/accelerated-tensors.rkt @@ -0,0 +1,45 @@ +#lang racket/base + +(require + (except-in "accelerated-tensors/tensors.rkt" + rank shape reshape tref trefs tensor? tlen ref refr)) + +(require "accelerated-tensors/autodiff.rkt") +(require "accelerated-tensors/ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + max-tensor-print-length make-printable + + (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) + (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) + (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) + (d-flatten flatten) + (d-concat concat) (d-concat-n concat-n)) + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + sum-1 argmax-1 max-1 flatten-2 concat-1-1 + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/autodiff.rkt b/accelerated-tensors/autodiff.rkt new file mode 100644 index 0000000..68168c9 --- /dev/null +++ b/accelerated-tensors/autodiff.rkt @@ -0,0 +1,23 @@ +#lang racket + +(require "autodiff/A-autodiff.rkt") +(require "autodiff/B-prims.rkt") +(require "autodiff/C-dualized-tensor-ops.rkt") +(require "autodiff/D-test-helpers.rkt") +(require "autodiff/E-print.rkt") + +(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual* map*) +(provide prim1 prim2 ext1 ext2) +(provide (rename-out (d-rank rank) + (d-shape shape) + (d-reshape reshape) + (d-tref tref) + (d-trefs trefs) + (d-tensor? tensor?) + (d-tlen tlen) + (d-ref ref) + (d-refr refr))) + +(provide check-dual-equal? check-ρ-∇) + +(provide max-tensor-print-length make-printable) diff --git a/accelerated-tensors/autodiff/A-autodiff.rkt b/accelerated-tensors/autodiff/A-autodiff.rkt new file mode 100644 index 0000000..03b9c93 --- /dev/null +++ b/accelerated-tensors/autodiff/A-autodiff.rkt @@ -0,0 +1,120 @@ +#lang racket + +(require "../tensors.rkt") + +;;---------------------------- +;; Real part of a dual is always a tensor (of any rank) +;;---------------------------- + +(define dual? + (λ (x) + (and (vector? x) (eq? (vector-ref x 0) dual)))) + +(define dual + (λ (r k) + (vector dual r k))) + +(define dual* + (λ (d) + (dual (ρ d) end-of-chain))) + +(define ρ + (λ (d) + (cond + ((dual? d) (vector-ref d 1)) + (else d)))) + +(define κ + (λ (d) + (cond + ((dual? d) (vector-ref d 2)) + (else end-of-chain)))) + +(define scalar? + (λ (d) + (or (number? d) + (and (dual? d) + (number? (ρ d)))))) + +(define dual-like? + (λ (d) + (or (dual? d) + (number? d) + (vector? d)))) + +;;---------------------------- +;; Chain rule +;;---------------------------- + +(define end-of-chain + (λ (d z σ) + (let ((g (hash-ref σ d 0.0))) + (hash-set σ d (+-ρ z g))))) + +(define +-ρ + (ext2-ρ + 0 0)) + +;;---------------------------- +;; Reverse-mode AD +;;---------------------------- + +(define ∇ + (λ (f theta) + (let ((wrt (map* dual* theta))) + (∇-once (f wrt) wrt)))) + +(define ∇¹ + (λ (f) + (λ xs + (let ((wrt (map* dual* xs))) + (∇-once (apply f wrt) wrt))))) + +(define ∇-once + (λ (y wrt) + (let ((σ (∇σ y (hasheq)))) + (map* (λ (d) + (hash-ref σ d 0.0)) + wrt)))) + +(define ∇σ + (λ (y σ) + (cond + ((dual-like? y) ((κ y) y (one-like (ρ y)) σ)) + ((list? y) (∇σ-list y σ)) + (else (printf "Unknown: ~a~%" y))))) + +(define ∇σ-list + (λ (y σ) + (cond + ((null? y) σ) + (else + (let ((σ-hat (∇σ (ref y 0) σ))) + (∇σ-list (refr y 1) σ-hat)))))) + +;;---------------------------- +;; General helpers +;;---------------------------- + +(define map* + (λ (f y) + (cond + ((dual-like? y) (f y)) + ((list? y) + (map (λ (yi) + (map* f yi)) + y)) + (else y)))) + +(define trace-print + (λ (v port) + (cond + ((dual? v) (trace-print (ρ v) port)) + (else (fprintf port "~a~%" v))))) + +(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) + +(include "test/test-A-autodiff.rkt") + +(provide + dual dual? ρ κ ∇ ∇¹ dual* scalar? end-of-chain map* + trace-print) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt new file mode 100644 index 0000000..e029cc9 --- /dev/null +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -0,0 +1,221 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(define ρ-function + (λ (f) (f ρ-function))) + +(define ∇-function + (λ (f) (f ∇-function))) + +(define shape-fn + (λ (f) (f shape-fn))) + +;; For flat tensors, ρ-fn and ∇-fn +;; are of two types: functional and pre-allocated +;; When they are functional, they return values +;; When they are pre-allocated, they expect expect the +;; return flat-store to be pre-allocated, and simply +;; operate as fillers. +;; +;; Pre-allocated ρ and ∇ have arities +;; 6 and 7 for unary ops, and 9 and 10 for binary ops. +;; We test for this arity to determine the type. +;; +;; Generally speaking, scalar operations are functional +;; and vector operations are pre-allocated. +;; +;; The functions ensure-ρ-callable-1, ensure-∇-callable-1 +;; and ensure-ρ-callable-2, ensure-∇-callable-2 provide +;; the preallocation for flat-stores when a vector-op is +;; provided, but the invocation of prim1 expects functional +;; results. +;; + +(define prim1 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) + (∇-callable (ensure-∇-callable-1 ∇-fn shape))) + (λ (daf) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim1-dual ρ-callable ∇-callable daf))))))) + +(define prim1-dual + (λ (ρ-fn ∇-fn da) + (let ((ra (ρ da))) + (dual (ρ-fn ra) + (λ (d z σ) + (let ((ga (∇-fn ra z))) + ((κ da) da ga σ))))))) + +(define prim2 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) + (∇-callable (ensure-∇-callable-2 ∇-fn shape))) + (λ ds + (let ((daf (ref ds 0))) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1))))))))) + +(define prim2-dual + (λ (ρ-fn ∇-fn da db) + (let ((ra (ρ da)) + (rb (ρ db))) + (dual (ρ-fn ra rb) + (λ (d z σ) + (let-values (((ga gb) (∇-fn ra rb z))) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat)))))))) + +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define ensure-ρ-callable-1 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra) + (apply-flat-ρ-fn-1 ρ-fn ra shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-1 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra z) + (apply-flat-∇-fn-1 ∇-fn ra z shape-fn))) + (else ∇-fn)))) + +(define ensure-ρ-callable-2 + (λ (ρ-fn shape-fn) + (cond + ((expects-preallocated? ρ-fn) + (λ (ra rb) + (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn))) + (else ρ-fn)))) + +(define ensure-∇-callable-2 + (λ (∇-fn shape-fn) + (cond + ((expects-preallocated? ∇-fn) + (λ (ra rb z) + (apply-flat-∇-fn-1 ∇-fn ra rb z shape-fn))) + (else ∇-fn)))) + +(define apply-flat-ρ-fn-1 + (λ (ρ-fn ra shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn (flat-store ra) (flat-offset ra) in-size + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-1 + (λ (∇-fn ra z shape-fn) + (let* ((in-shape (flat-shape ra)) + (in-size (size-of in-shape)) + (out-shape (shape-fn in-shape)) + (out-size (size-of out-shape))) + (let ((g (new-vec in-size 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g (flat-store ra) (flat-offset ra) in-size + v-z 0 1) + (flat in-shape g 0))) + (else + (∇-fn g (flat-store ra) (flat-offset ra) in-size + (flat-store z) (flat-offset z) out-size) + (flat in-shape g 0))))))) + +(define apply-flat-ρ-fn-2 + (λ (ρ-fn ra rb shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape ra)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (cond + ((null? out-shape) + (let ((v-out (new-vec 1 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 1) + (vref v-out 0))) + (else + (let ((v-out (new-vec out-size 0.0))) + (ρ-fn + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-out 0 out-size) + (flat out-shape v-out 0))))))) + +(define apply-flat-∇-fn-2 + (λ (∇-fn ra rb z shape-fn) + (let* ((in-shape-a (flat-shape ra)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape ra)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (let ((g0 (new-vec in-size-a 0.0)) + (g1 (new-vec in-size-b 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + v-z 0 1) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))) + (else + (∇-fn g0 g1 + (flat-store ra) (flat-offset ra) in-size-a + (flat-store rb) (flat-offset rb) in-size-b + (flat-store z) (flat-offset z) out-size) + (values + (flat in-shape-a g0 0) + (flat in-shape-b g1 0)))))))) + +;;---------------------------- +;; Dualized tensor op creators +;;---------------------------- +(define ext1 + (λ (f n) + (prim1 + (ext1-ρ (ρ-function f) n (shape-fn f)) + (ext1-∇ (∇-function f) n (shape-fn f)) + (shape-fn f)))) + +(define ext2 + (λ (f m n) + (prim2 + (ext2-ρ (ρ-function f) m n (shape-fn f)) + (ext2-∇ (∇-function f) m n (shape-fn f)) + (shape-fn f)))) + +(provide prim1 prim2 ext1 ext2) diff --git a/accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt b/accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt new file mode 100644 index 0000000..5b02ba3 --- /dev/null +++ b/accelerated-tensors/autodiff/C-dualized-tensor-ops.rkt @@ -0,0 +1,51 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + + +;;---------------------------- +;; Tensor ops, cleaned up. +;;---------------------------- + +(define d-rank + (lambda (t) + (rank (ρ t)))) + +(define d-shape + (λ (t) + (shape (ρ t)))) + +(define d-reshape + (λ (s t) + (cond + ((dual? t) + (dual (reshape s (ρ t)) + (κ t))) + (else (reshape s t))))) + +(define d-trefs + (λ (t b) + (trefs (ρ t) b))) + +(define d-tref + (λ (t i) + (tref (ρ t) i))) + +(define d-tensor? + (λ (t) + (tensor? (ρ t)))) + +(define d-tlen + (λ (t) + (tlen (ρ t)))) + +(define d-ref + (λ (l i) + (ref l (ρ i)))) + +(define d-refr + (λ (l i) + (refr l (ρ i)))) + +(provide d-rank d-shape d-reshape d-trefs d-tensor? d-tlen d-ref d-refr d-tref) diff --git a/accelerated-tensors/autodiff/D-test-helpers.rkt b/accelerated-tensors/autodiff/D-test-helpers.rkt new file mode 100644 index 0000000..5b21797 --- /dev/null +++ b/accelerated-tensors/autodiff/D-test-helpers.rkt @@ -0,0 +1,47 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(require rackunit) + +(define-binary-check (check-dual-equal? equal-wt? actual expected)) +(define-check (ρ-∇-checker fn args ans grads) + (let* ((y (apply fn args)) + (g (apply (∇¹ fn) args))) + (cond + ((and (equal-wt? ans (ρ y)) + (equal-wt? grads (ρ g))) (void)) + ((equal-wt? ans (ρ y)) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" + (ρ g) grads))) + (else + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" + (ρ y) ans)))))) + +(define-syntax check-ρ-∇ + (syntax-rules () + [(check-both (fn args ...) ans grads) + (ρ-∇-checker fn (list args ...) ans grads)])) + +(define equal-wt? + (λ (a b) + (cond + ((and (tensor? a) (tensor? b)) + (tensor-equal? a b)) + ((dual? a) (equal-wt? (ρ a) b)) + ((dual? b) (equal-wt? a (ρ b))) + ((and (vector? a) (vector? b) + (= (vector-length a) (vector-length b))) + (vector-andmap equal-wt? a b)) + ((and (pair? a) (pair? b) + (= (length a) (length b))) + (andmap equal-wt? a b)) + (else (equal? a b))))) + +(define vector-andmap + (λ (f v1 v2) + (for/fold ([s #t]) ([v1 v1][v2 v2]) + (and s (f v1 v2))))) + +(provide check-dual-equal? check-ρ-∇) diff --git a/accelerated-tensors/autodiff/E-print.rkt b/accelerated-tensors/autodiff/E-print.rkt new file mode 100644 index 0000000..7f4090e --- /dev/null +++ b/accelerated-tensors/autodiff/E-print.rkt @@ -0,0 +1,87 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require "A-autodiff.rkt") +(require "../tensors.rkt") + +(define max-tensor-print-length (make-parameter 5)) + +(struct fake-tensor (members) + #:transparent + #:methods gen:custom-write + ((define write-proc + (λ (fake-tensor port mode) + (let ((n (length (fake-tensor-members fake-tensor)))) + (case mode + ((#t) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (write m port)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + ((#f) + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (display m port) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)) + (else + (display "(tensor " port) + (for ([m (fake-tensor-members fake-tensor)] + [c (in-range 0 n)]) + (if (symbol? m) + (display m port) + (print m port mode)) + (when (< c (- n 1)) + (display " " port))) + (display ")" port)))))))) + +(define make-printable + (λ (y [max-length (max-tensor-print-length)]) + (cond + ((dual? y) (make-printable (ρ y))) + ((flat? y) (make-printable-flat y max-length)) + ((list? y) + (map (λ (le) (make-printable le max-length)) y)) + ((vector? y) + (vector-map (λ (ve) (make-printable ve max-length)) y)) + (else y)))) + +(define make-printable-flat + (λ (y max-length) + (flat->tensor-list + (flat-store y) (flat-offset y) (flat-shape y) + (strides (flat-shape y)) max-length))) + +(define flat->tensor-list + (λ (store offset shape strides max-length) + (cond + ((null? shape) (vref store offset)) + (else + (let ((top-len (car shape)) + (stride (car strides))) + (fake-tensor + (reverse + (call/cc + (λ (return) + (for/fold ((lst '())) ((i (in-range offset (+ offset (* top-len stride)) stride)) + (count (in-naturals 0))) + (cond + ((and (> max-length 0) (= count max-length)) (return (cons '... lst))) + (else + (cons (flat->tensor-list store i (cdr shape) (cdr strides) max-length) + lst))))))))))))) + +(include "test/test-E-print.rkt") + +(provide max-tensor-print-length + make-printable + ;; This is used in ext-impl.rkt + make-printable-flat + fake-tensor) diff --git a/accelerated-tensors/autodiff/test/test-A-autodiff.rkt b/accelerated-tensors/autodiff/test/test-A-autodiff.rkt new file mode 100644 index 0000000..b453b3e --- /dev/null +++ b/accelerated-tensors/autodiff/test/test-A-autodiff.rkt @@ -0,0 +1,15 @@ +(module+ test + (require rackunit) + (let ((k0 end-of-chain)) + (let ((dual0 0) + (dual1 (dual 1 k0))) + + (check-equal? dual1 (dual 1 k0)) + (check-true (dual? dual1)) + (check-false (dual? 1)) + (check-equal? (ρ dual1) 1) + (check-equal? (ρ dual0) 0) + (check-equal? (κ dual1) k0) + + (check-equal? (map* (λ (d) (ρ d)) (∇-once dual1 (list dual0 dual1))) + '(0.0 1.0))))) diff --git a/accelerated-tensors/autodiff/test/test-E-print.rkt b/accelerated-tensors/autodiff/test/test-E-print.rkt new file mode 100644 index 0000000..30f51e6 --- /dev/null +++ b/accelerated-tensors/autodiff/test/test-E-print.rkt @@ -0,0 +1,71 @@ +(module+ test + (require rackunit) + (require "../tensors.rkt") + + (define long-tensor + (tensor 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15)) + + (define dualized-long-tensor + (dual long-tensor end-of-chain)) + + (define deep-tensor + (tensor long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor)) + + (define deeper-tensor + (tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) + + (check-equal? (make-printable-flat long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? (make-printable-flat deep-tensor 3) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...))) + + (check-equal? (make-printable-flat deeper-tensor 3) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...))) + (parameterize ((max-tensor-print-length 3)) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) + (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) + (list + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + '...)) + '...)))))) diff --git a/accelerated-tensors/ext-impl.rkt b/accelerated-tensors/ext-impl.rkt new file mode 100644 index 0000000..6eb7d34 --- /dev/null +++ b/accelerated-tensors/ext-impl.rkt @@ -0,0 +1,28 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require (only-in "tensors/B-tensor-basics.rkt" + merge-flats)) +(require (only-in "tensors/D-extend.rkt" + merge-shapes + min-shape + ext2-shapes + flat-ext1-∇ + flat-ext1-ρ + flat-ext2-ρ + functional->preallocated-1-ρ + functional->preallocated-1-∇ + functional->preallocated-2-ρ + functional->preallocated-2-∇ + idxs + scalarize + ensure-flat)) +(require (only-in "autodiff/E-print.rkt" + make-printable-flat + fake-tensor)) + +(provide (all-from-out "tensors/0-vectors.rkt")) +(provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/B-tensor-basics.rkt")) +(provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/E-print.rkt")) diff --git a/accelerated-tensors/ext-ops.rkt b/accelerated-tensors/ext-ops.rkt new file mode 100644 index 0000000..83af7de --- /dev/null +++ b/accelerated-tensors/ext-ops.rkt @@ -0,0 +1,40 @@ +#lang racket + +(require "ext-ops/A-scalar-ops.rkt") +(require "ext-ops/B-comparators.rkt") +(require "ext-ops/C-star-2-1.rkt") +(require "ext-ops/D-sum.rkt") +(require "ext-ops/E-argmax.rkt") +(require "ext-ops/F-max.rkt") +(require "ext-ops/G-correlate.rkt") +(require "ext-ops/I-flatten.rkt") +(require "ext-ops/K-concat.rkt") + +(provide d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ) + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) + +(provide d*-2-1 *-2-1-ρ) + +(provide sum-1 d-sum sum-ρ d-sum-cols sum-cols-ρ) + +(provide argmax-1 d-argmax argmax-ρ) + +(provide max-1 d-max max-ρ) + +(provide correlate-ρ d-correlate) + +(provide flatten-2 d-flatten flatten-ρ) + +(provide concat-1-1 d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/accelerated-tensors/ext-ops/A-scalar-ops.rkt b/accelerated-tensors/ext-ops/A-scalar-ops.rkt new file mode 100644 index 0000000..b251e80 --- /dev/null +++ b/accelerated-tensors/ext-ops/A-scalar-ops.rkt @@ -0,0 +1,135 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) +(require "../autodiff.rkt") + +(define +-0-0 + (prim2 + + (λ (a b z) + (values z z)))) + +(define --0-0 + (prim2 - + (λ (a b z) + (values z (- z))))) + +(define *-0-0 + (prim2 * + (λ (a b z) + (values (* b z) (* a z))))) + +(define /-0-0 + (prim2 / + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))))) + +(define expt-0-0 + (prim2 expt + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))))) + +(define exp-0 + (prim1 exp + (λ (a z) + (* z (exp a))))) + +(define log-0 + (prim1 log + (λ (a z) + (* z (/ 1 a))))) + +(define sqrt-0 + (prim1 sqrt + (λ (x z) + (/ z (* 2 (sqrt x)))))) + +(define abs-0-ρ + (λ (x) + (cond + ((< x 0) (* -1 x)) + (else x)))) + +(define abs-0-∇ + (λ (x z) + (cond + ((< x 0) (- z)) + (else z)))) + +(define abs-0 + (prim1 abs-0-ρ abs-0-∇)) + +(define rectify-0-ρ + (λ (s) + (cond + ((< s 0.0) 0.0) + (else s)))) + +(define rectify-0-∇ + (λ (s z) + (cond + ((< s 0.0) 0.0) + (else z)))) + +(define rectify-shape + (λ (s) s)) + +(define rectify-0 + (prim1 rectify-0-ρ rectify-0-∇ rectify-shape)) + +;;------------------------------------ +;; differentiable extended functions. +;;------------------------------------ + +(define d* (ext2 *-0-0 0 0)) +(define d+ (ext2 +-0-0 0 0)) +(define d- (ext2 --0-0 0 0)) +(define d/ (ext2 /-0-0 0 0)) +(define d-expt (ext2 expt-0-0 0 0)) + +(define d-exp (ext1 exp-0 0)) +(define d-log (ext1 log-0 0)) +(define d-abs (ext1 abs-0 0)) +(define d-rectify (ext1 rectify-0 0)) +(define d-sqrt (ext1 sqrt-0 0)) + +(define d-sqr + (λ (x) + (d* x x))) + +;;------------------------------------ +;; non-differentiable extended functions. +;;------------------------------------ + +(define *-ρ (ext2-ρ * 0 0)) +(define +-ρ (ext2-ρ + 0 0)) +(define --ρ (ext2-ρ - 0 0)) +(define /-ρ (ext2-ρ / 0 0)) +(define expt-ρ (ext2-ρ expt 0 0)) + +(define exp-ρ (ext1-ρ exp 0)) +(define log-ρ (ext1-ρ log 0)) +(define abs-ρ (ext1-ρ abs-0-ρ 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) + +(define sqrt-ρ + (λ (a) + (expt-ρ a 1/2))) + +(define sqr-ρ + (λ (x) + (*-ρ x x))) + +(include "test/test-A-scalar-ops.rkt") + +(provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + + d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ) diff --git a/accelerated-tensors/ext-ops/B-comparators.rkt b/accelerated-tensors/ext-ops/B-comparators.rkt new file mode 100644 index 0000000..c42a2cf --- /dev/null +++ b/accelerated-tensors/ext-ops/B-comparators.rkt @@ -0,0 +1,85 @@ +#lang racket + +(require "../autodiff.rkt") + +;;---------------------------- +;; Boolean comparators +;;---------------------------- + +(define comparator + (λ (f) + (λ (da db) + (f (ρ da) (ρ db))))) + +(define =-0-0 + (comparator =)) + +(define <-0-0 + (comparator <)) + +(define <=-0-0 + (comparator <=)) + +(define >-0-0 + (comparator >)) + +(define >=-0-0 + (comparator >)) + +;;---------------------------- +;; Tensorized comparators +;;---------------------------- + +(define comparator-ρ + (λ (f) + (λ (da db) + (cond + ((f (ρ da) (ρ db)) 1.0) + (else 0.0))))) + +(define comparator-∇ + (λ (f) + (λ (da db z) + (cond + ((f (ρ da) (ρ db)) (values z z)) + (else (values 0.0 0.0)))))) + +(define comparator-shape + (λ (f) + (λ (sa sb) + sa))) + +(define comparator-prim + (λ (f) + (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + +(define extended-comparator + (λ (f) + (ext2 (comparator-prim f) 0 0))) + +(define =-1 + (extended-comparator =)) + +(define <-1 + (extended-comparator <)) + +(define >-1 + (extended-comparator >)) + +(define <=-1 + (extended-comparator <=)) + +(define >=-1 + (extended-comparator >=)) + +(define != + (λ (a b) + (not (= a b)))) + +(define !=-1 + (extended-comparator !=)) + +(include "test/test-B-comparators.rkt") + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/ext-ops/C-star-2-1.rkt b/accelerated-tensors/ext-ops/C-star-2-1.rkt new file mode 100644 index 0000000..5eb0d63 --- /dev/null +++ b/accelerated-tensors/ext-ops/C-star-2-1.rkt @@ -0,0 +1,44 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ)) +(require "../autodiff.rkt") + +(define *-2-1-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (vset! v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (let ((a (vref v0 (+ i0 i))) + (b (vref v1 (+ i1 (modulo i stride1)))) + (z (vref vz (+ iz i)))) + (vset! g0 (+ i0 i) + (+ (vref g0 (+ i0 i)) (* z b))) + (vset! g1 (+ i1 (modulo i stride1)) + (+ (vref g1 (+ i1 (modulo i stride1))) (* z a))))))) + +(define *-2-1-shape + (λ (s t) + s)) + +(define *-2-1 + (prim2 *-2-1-base-ρ *-2-1-base-∇ *-2-1-shape)) + +(define d*-2-1 + (ext2 *-2-1 2 1)) + +(define *-2-1-ρ + (ext2-ρ *-2-1-base-ρ 2 1 *-2-1-shape)) + +(include "test/test-C-star-2-1.rkt") + +(provide *-2-1-ρ d*-2-1) diff --git a/accelerated-tensors/ext-ops/D-sum.rkt b/accelerated-tensors/ext-ops/D-sum.rkt new file mode 100644 index 0000000..44c6b8e --- /dev/null +++ b/accelerated-tensors/ext-ops/D-sum.rkt @@ -0,0 +1,68 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define sum-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([sum 0.0]) ([i (in-range i0 (+ i0 stride0))]) + (+ sum (vref v0 i)))))) + +(define sum-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i + (+ (vref g0 i) z)))))) + +(define sum-shape + (λ (st) + (refr st 1))) + +(define sum-1 + (prim1 sum-1-ρ sum-1-∇ sum-shape)) + +(define d-sum + (ext1 sum-1 1)) + +(define sum-ρ + (ext1-ρ sum-1-ρ 1 sum-shape)) + +(provide d-sum sum-ρ) + +(define sum-cols-2-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (for ((i (in-range 0 stride-out))) + (vset! v-out (+ i i-out) + (for/fold ([sum 0.0]) ([j (in-range i0 (+ i0 stride0) stride-out)]) + (+ sum (vref v0 (+ j i)))))))) + +(define sum-cols-2-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (for ((i (in-range 0 stride-z))) + (for ([j (in-range i0 (+ i0 stride0) stride-z)]) + (vset! g0 (+ i j) + (+ (vref g0 (+ i j)) (vref vz (+ i iz)))))))) + +(define sum-cols-shape + (λ (s) + (refr s 1))) + +(define sum-cols-2 + (prim1 sum-cols-2-ρ sum-cols-2-∇ sum-cols-shape)) + +(define d-sum-cols + (ext1 sum-cols-2 2)) + +(define sum-cols-ρ + (ext1-ρ sum-cols-2-ρ 2 sum-cols-shape)) + +(include "test/test-D-sum.rkt") + +(provide sum-1 d-sum-cols sum-cols-ρ) diff --git a/accelerated-tensors/ext-ops/E-argmax.rkt b/accelerated-tensors/ext-ops/E-argmax.rkt new file mode 100644 index 0000000..d68966d --- /dev/null +++ b/accelerated-tensors/ext-ops/E-argmax.rkt @@ -0,0 +1,41 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define argmax-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([max 0.0] + [max-i -1] #:result max-i) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) (values v (+ (- i i0) 0.0))) + (else (values max max-i)))))))) + +(define argmax-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vset! g0 i 0.0))))) + +(define argmax-shape + (λ (st) + '())) + +(define argmax-1 + (prim1 argmax-1-ρ argmax-1-∇ argmax-shape)) + +(define d-argmax + (ext1 argmax-1 1)) + +(define argmax-ρ + (ext1-ρ argmax-1-ρ 1 argmax-shape)) + +(include "test/test-E-argmax.rkt") + +(provide argmax-1 d-argmax argmax-ρ) diff --git a/accelerated-tensors/ext-ops/F-max.rkt b/accelerated-tensors/ext-ops/F-max.rkt new file mode 100644 index 0000000..2d5f5d6 --- /dev/null +++ b/accelerated-tensors/ext-ops/F-max.rkt @@ -0,0 +1,49 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define max-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vset! v-out i-out + (for/fold ([max 0.0]) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) v) + (else max))))))) + +(define max-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vref vz iz))) + (for/fold ([max -inf.0] + [max-i -1] #:result + (for ([i (in-range i0 (+ i0 stride0))]) + (cond + ((= i (+ i0 max-i)) (vset! g0 i z)) + (else (vset! g0 i 0.0))))) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vref v0 i))) + (cond + ((> v max) (values v (- i i0))) + (else (values max max-i)))))))) + +(define max-shape + (λ (st) + (cdr st))) + +(define max-1 + (prim1 max-1-ρ max-1-∇ max-shape)) + +(define d-max + (ext1 max-1 1)) + +(define max-ρ + (ext1-ρ max-1-ρ 1 max-shape)) + +(include "test/test-F-max.rkt") + +(provide max-1 d-max max-ρ) diff --git a/accelerated-tensors/ext-ops/G-correlate.rkt b/accelerated-tensors/ext-ops/G-correlate.rkt new file mode 100644 index 0000000..9db2109 --- /dev/null +++ b/accelerated-tensors/ext-ops/G-correlate.rkt @@ -0,0 +1,95 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (only-in "../tensors.rkt" ext2-ρ len)) +(require "../autodiff.rkt") + +;; Correlation is written taking into account how ext2 works +;; Ext2 is responsible for producing the i-out'th output from +;; v0[i0] and v1[i1], we take advantage of this. The shape constants +;; n b m d are pre-calculated the striding constants nd md and qd +;; are calculated. + +(define correlate-3-1-ρ + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (vset! v-out (+ i-out i) + (for/fold ([sum 0.0]) ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (cond + ((and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (+ sum (* a b)))) + (else sum)))))))))) + +(define correlate-3-1-∇ + (λ (nd md qd) + (λ (g0 g1 + v0 i0 bmd + v1 i1 d + vz iz b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (let ((z (vref vz (+ iz i)))) + (for ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (when (and (>= bi i1-min) (< bi i1-max)) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-shape + (λ (bmd nd) + (list (car bmd)))) + +(define correlate-3-1 + (λ (nd md qd) + (prim2 + (correlate-3-1-ρ nd md qd) + (correlate-3-1-∇ nd md qd) + correlate-shape))) + +(define d-correlate + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2 (correlate-3-1 nd md qd) 3 1) bank signal)))) + +(define correlate-ρ + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape) + bank signal)))) + +(define last + (λ (n s) + (refr s (- (len s) n)))) + +(include "test/test-G-correlate.rkt") + +(provide d-correlate correlate-ρ) diff --git a/accelerated-tensors/ext-ops/I-flatten.rkt b/accelerated-tensors/ext-ops/I-flatten.rkt new file mode 100644 index 0000000..bf24773 --- /dev/null +++ b/accelerated-tensors/ext-ops/I-flatten.rkt @@ -0,0 +1,31 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) +(require (only-in "../autodiff.rkt" prim1 ext1)) + +(define flatten-2-ρ + (λ (t) + (reshape (flatten-shape (shape t)) t))) + +(define flatten-2-∇ + (λ (t z) + (reshape (shape t) z))) + +(define flatten-shape + (λ (s) + (let ((rows (ref s 0)) + (cols (ref s 1))) + (list (* rows cols))))) + +(define flatten-2 + (prim1 flatten-2-ρ flatten-2-∇ flatten-shape)) + +(define d-flatten + (ext1 flatten-2 2)) + +(define flatten-ρ + (ext1-ρ flatten-2-ρ 2)) + +(include "test/test-I-flatten.rkt") + +(provide flatten-2 d-flatten flatten-ρ) diff --git a/accelerated-tensors/ext-ops/K-concat.rkt b/accelerated-tensors/ext-ops/K-concat.rkt new file mode 100644 index 0000000..5cc1ee5 --- /dev/null +++ b/accelerated-tensors/ext-ops/K-concat.rkt @@ -0,0 +1,76 @@ +#lang racket + +(require "../tensors/0-vectors.rkt") +(require (rename-in (only-in "../tensors.rkt" ext2-ρ tref tlen shape len ref) + (shape shape-ρ))) +(require (only-in "../autodiff.rkt" prim2 ext2 shape)) + +(define concat-shape + (λ (st su) + (cons (+ (ref st 0) (ref su 0)) + (cdr st)))) + +(define concat-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (cond + ((< i stride0) + (vset! v-out (+ i-out i) (vref v0 (+ i0 i)))) + (else + (vset! v-out (+ i-out i) (vref v1 (+ i1 (- i stride0))))))))) + +(define concat-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (cond + ((< i stride0) + (vset! g0 (+ i0 i) + (+ (vref g0 (+ i0 i)) + (vref vz (+ iz i))))) + (else + (vset! g1 (+ i1 (- i stride0)) + (+ (vref g1 (+ i1 (- i stride0))) + (vref vz (+ iz i))))))))) + +(define concat-base + (prim2 concat-base-ρ concat-base-∇ concat-shape)) + +(define d-concat-n + (λ (n) + (λ (t u) + (let ((st (shape t)) + (su (shape u))) + (ensure-compatible-shapes n st su) + ((ext2 concat-base n n) t u))))) + +(define concat-n-ρ + (λ (n) + (λ (t u) + (let ((st (shape-ρ t)) + (su (shape-ρ u))) + (ensure-compatible-shapes n st su) + ((ext2-ρ concat-base-ρ n n concat-shape) t u))))) + +(define ensure-compatible-shapes + (λ (n st su) + (let ((rt (len st)) + (ru (len su))) + ;; The shape of the tensor of rank r at rank n-1 + ;; is given by (drop st (+ 1 (- r n))) + (when (not (equal? (drop st (+ 1 (- rt n))) (drop su (+ 1 (- ru n))))) + (error 'concat "Incompatible concat shapes: ~a and ~a at last ~a dimensions" + st su n))))) + +(define d-concat (d-concat-n 1)) +(define concat-ρ (concat-n-ρ 1)) +(define concat-1-1 concat-base) + +(include "test/test-K-concat.rkt") + +(provide concat-1-1 + d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt b/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt new file mode 100644 index 0000000..2c13e39 --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt @@ -0,0 +1,122 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + ;; Check basic numericals + (let ((a 2) + (b 3)) + (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) + (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0)) + (check-ρ-∇ (d* a b) 6 (list 3.0 2.0)) + (check-ρ-∇ (d/ a b) + 2/3 + '(0.3333333333333333 -0.2222222222222222)) + (check-ρ-∇ (d-exp a) (exp 2) (list (exp 2))) + (check-ρ-∇ (d-log a) (log 2) (list 0.5)) + (check-ρ-∇ (d-expt a b) 8 + (list 12.0 5.545177444479562)) + (check-ρ-∇ (d-sqrt a) + (sqrt 2) + (list 0.3535533905932738)) + + (let* ((first-derivative ((∇¹ d-sqr) b))) + (check-dual-equal? first-derivative (list 6.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* 2)) + (d (dual* 2))) + (check-dual-equal? + ((∇¹ z) c d) + (list 2.5 2.0))) + + (check-dual-equal? ((∇¹ z) a a) (list 2.5 2.0))) + + ;; Check numericals with vector-duals + + (let ((a (tensor 2.0 3.0 4.0)) + (b (tensor 3.0 8.0 9.0))) + (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) + (check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) + (check-ρ-∇ (d/ a b) + (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) + (list (tensor 0.3333333333333333 0.125 0.1111111111111111) + (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) + (check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) + (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) + (check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) + (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) + (check-ρ-∇ (d-expt a b) + (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) + (list (tensor 12.0 17496.0 589824.0) + (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) + + (check-ρ-∇ (d-sqrt a) + (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) + (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) + + (check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* (tensor 2.0 3.0 4.0))) + (d (dual* (tensor 2.0 3.0 4.0)))) + (check-dual-equal? + ((∇¹ z) c d) + (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + (check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + ;; Check numericals with lists + (let ((x (dual* 3)) + (y (dual* 2))) + (let ((f d*)) + (check-ρ-∇ (d* x y) 6 '(2.0 3.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list (f m m) (f n n)))) + x y) + '(6.0 4.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list + (list (f m m) (f n n)) + (list (f m m) (f n n))))) + x y) + '(12.0 8.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (map (λ (m) (f m m)) + (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list n m)))) + x y) + '(4.0 6.0)))) + + (let ((a 7) + (b (tensor 13))) + (check-ρ-∇ (d+ a b) (tensor 20) (list 1.0 (tensor 1.0))) + (check-ρ-∇ (d* a b) (tensor 91) (list 13.0 (tensor 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13) (list 0.07692 (tensor -0.04142)))) + + (let ((a 7) + (b (tensor 13 15))) + (check-ρ-∇ (d+ a b) (tensor 20 22) (list 2.0 (tensor 1.0 1.0))) + (check-ρ-∇ (d* a b) (tensor 91 105) (list 28.0 (tensor 7.0 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13 7/15) + (list 0.14358 (tensor -0.04142 -0.03111))))) diff --git a/accelerated-tensors/ext-ops/test/test-B-comparators.rkt b/accelerated-tensors/ext-ops/test/test-B-comparators.rkt new file mode 100644 index 0000000..9f3fdf5 --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-B-comparators.rkt @@ -0,0 +1,12 @@ +(module+ test + (require rackunit) + (let ((a 2) + (b 3)) + (check-true (<-0-0 a b)) + (check-false (>-0-0 a b)) + (check-true (<=-0-0 a b)) + (check-false (>=-0-0 a b)) + (check-false (=-0-0 a b)) + (check-true (=-0-0 a a)) + (check-true (zero? 0)) + (check-false (zero? a)))) diff --git a/accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt b/accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt new file mode 100644 index 0000000..bbb2c8b --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-C-star-2-1.rkt @@ -0,0 +1,24 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (tensor (tensor 36 52 70 90) (tensor 84 104 126 150))) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/accelerated-tensors/ext-ops/test/test-D-sum.rkt b/accelerated-tensors/ext-ops/test/test-D-sum.rkt new file mode 100644 index 0000000..7b07d0d --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-D-sum.rkt @@ -0,0 +1,58 @@ +(module+ test + (require rackunit) + (require "C-star-2-1.ss") + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) + + (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d-sum a) 12 + (list (tensor 1.0 1.0 1.0)))) + + (let ((a (tensor (tensor 3 4 5) + (tensor 6 7 8)))) + (check-dual-equal? (d-sum a) (tensor 12 21)) + (check-dual-equal? ((∇¹ (λ (b) (d-sum (d* b b)))) a) + (list (tensor (tensor 6.0 8.0 10.0) + (tensor 12.0 14.0 16.0))))) + + (define dot-product + (λ (a b) + (d-sum (d*-2-1 a b)))) + + (define sse + (λ (a b) + (d-sum (d-sqr (d- a b))))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + + (check-ρ-∇ (sum-1 b) 14 + (list (tensor 1.0 1.0 1.0 1.0))) + + (check-ρ-∇ (dot-product a b) + (tensor 68 124) + (list (tensor (tensor 2.0 3.0 4.0 5.0) + (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + + (check-ρ-∇ (sse a b) + (tensor 4 100) + (list (tensor (tensor 2.0 2.0 2.0 2.0) + (tensor 10.0 10.0 10.0 10.0)) + (tensor -12.0 -12.0 -12.0 -12.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (dot-product a b) + (tensor (tensor 68 124) + (tensor 248 464)) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/accelerated-tensors/ext-ops/test/test-E-argmax.rkt b/accelerated-tensors/ext-ops/test/test-E-argmax.rkt new file mode 100644 index 0000000..f72819e --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-E-argmax.rkt @@ -0,0 +1,17 @@ +(module+ test + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-argmax y) (tensor 2.0 1.0 0.0 3.0) + (list + (tensor (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0)))))) diff --git a/accelerated-tensors/ext-ops/test/test-F-max.rkt b/accelerated-tensors/ext-ops/test/test-F-max.rkt new file mode 100644 index 0000000..01ab1a5 --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-F-max.rkt @@ -0,0 +1,10 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-max y) (tensor 1.0 1.0 1.0 1.0) + (list y)))) diff --git a/accelerated-tensors/ext-ops/test/test-G-correlate.rkt b/accelerated-tensors/ext-ops/test/test-G-correlate.rkt new file mode 100644 index 0000000..417723c --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-G-correlate.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor ext2-∇ check-tensor-equal?)) + + ;; for testing b = 4 + ;; m = 3 + ;; d = 2 + + ;; signal length n = 6 + + ;; (1 2) (3 4) (5 6) (7 8) (9 10) (11 12) + ;; (1 2) (3 4) (5 6) + ;; (7 8) (9 10) (11 12) + ;; (13 14) (15 16) (17 18) + ;; (19 20) (21 22) (23 24) + + ;; Signal is (n d) + (define signal (tensor (tensor 1 2) + (tensor 3 4) + (tensor 5 6) + (tensor 7 8) + (tensor 9 10) + (tensor 11 12))) + + (define bank (tensor (tensor + (tensor 1 2) + (tensor 3 4) + (tensor 5 6)) + (tensor + (tensor 7 8) + (tensor 9 10) + (tensor 11 12)) + (tensor + (tensor 13 14) + (tensor 15 16) + (tensor 17 18)) + (tensor + (tensor 19 20) + (tensor 21 22) + (tensor 23 24)))) + + (define corr-ρ + (ext2-ρ (correlate-3-1-ρ 12 6 2) 3 1 correlate-shape)) + + (define corr-∇ + (ext2-∇ (correlate-3-1-∇ 12 6 2) 3 1 correlate-shape)) + + (check-tensor-equal? (corr-ρ bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let-values (((filter-∇ signal-∇) + (corr-∇ bank signal (tensor (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0))))) + (check-tensor-equal? filter-∇ + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-tensor-equal? signal-∇ + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0)))) + + (check-dual-equal? (d-correlate bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let ((gs ((∇¹ d-correlate) bank signal))) + (check-dual-equal? (car gs) + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-dual-equal? (cadr gs) + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0))))) diff --git a/accelerated-tensors/ext-ops/test/test-I-flatten.rkt b/accelerated-tensors/ext-ops/test/test-I-flatten.rkt new file mode 100644 index 0000000..f7740cf --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-I-flatten.rkt @@ -0,0 +1,13 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0)) + + (check-dual-equal? (flatten-2 r2-t1) r1-t1) + (check-ρ-∇ ((λ (t1 t2) (d* t1 (flatten-2 t2))) r1-t1 r2-t1) + (tensor 9.0 16.0 25.0 36.0) + (list (tensor 3.0 4.0 5.0 6.0) (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))))) diff --git a/accelerated-tensors/ext-ops/test/test-K-concat.rkt b/accelerated-tensors/ext-ops/test/test-K-concat.rkt new file mode 100644 index 0000000..b427ced --- /dev/null +++ b/accelerated-tensors/ext-ops/test/test-K-concat.rkt @@ -0,0 +1,126 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t2 (tensor 5.0 6.0 7.0)) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + + (check-dual-equal? + (d-concat r2-t1 r1-t2) + (tensor (tensor 3.0 4.0 5.0 6.0 7.0) + (tensor 5.0 6.0 5.0 6.0 7.0))) + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (d-concat t1 t2))) r2-t1 r1-t2 r1-t1) + (tensor (tensor 9.0 16.0 25.0 36.0 49.0) + (tensor 15.0 24.0 25.0 36.0 49.0)) + (list (tensor (tensor 3.0 4.0) (tensor 3.0 4.0)) + (tensor 10.0 12.0 14.0) + (tensor 8.0 10.0 10.0 12.0 14.0))) + (define r3-t1 + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0)))) + + + (define r2-t2 + (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0))) + + (define r1-t3 + (tensor 0.5 0.5)) + + (define concat-2 (d-concat-n 2)) + + (check-dual-equal? + (concat-2 r3-t1 r2-t2) + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)))) + + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (concat-2 t1 t2))) r3-t1 r2-t2 r1-t3) + (tensor (tensor (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 4.5 5.0) + (tensor 5.5 6.0) + (tensor 6.5 7.0) + (tensor 7.5 8.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 8.5 9.0) + (tensor 9.5 10.0) + (tensor 10.5 11.0) + (tensor 11.5 12.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0))) + (list + (tensor (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5))) + + (tensor (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5)) + + (tensor 192.0 216.0)))) diff --git a/accelerated-tensors/no-duals-no-overrides.rkt b/accelerated-tensors/no-duals-no-overrides.rkt new file mode 100644 index 0000000..ac07a7a --- /dev/null +++ b/accelerated-tensors/no-duals-no-overrides.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + concat-ρ concat-n-ρ flatten-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/no-duals.rkt b/accelerated-tensors/no-duals.rkt new file mode 100644 index 0000000..927c8c7 --- /dev/null +++ b/accelerated-tensors/no-duals.rkt @@ -0,0 +1,29 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) + (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/no-overrides.rkt b/accelerated-tensors/no-overrides.rkt new file mode 100644 index 0000000..35dcbdd --- /dev/null +++ b/accelerated-tensors/no-overrides.rkt @@ -0,0 +1,43 @@ +#lang racket/base + +(require + (except-in "tensors.rkt" + rank shape reshape trefs tref tensor? tlen ref refr)) + +(require "autodiff.rkt") +(require "ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + make-printable + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + flatten-2 concat-1-1 + + d+ d- d* d/ d-rectify + d-exp d-log d-expt d-sqrt d-sqr + d-sum d-abs d*-2-1 d-argmax + d-max d-sum-cols d-correlate + d-flatten d-concat d-concat-n + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + flatten-ρ concat-ρ concat-n-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/tensors.rkt b/accelerated-tensors/tensors.rkt new file mode 100644 index 0000000..0daca26 --- /dev/null +++ b/accelerated-tensors/tensors.rkt @@ -0,0 +1,22 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require "tensors/A-equality.rkt") +(require "tensors/B-tensor-basics.rkt") +(require "tensors/C-tensor-ops.rkt") +(require "tensors/D-extend.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide tolerance tensor-equal? check-tensor-equal?) + +(provide len ref refr) +(provide tref tlen list->tensor tensor build-tensor trefs) + +(provide ext1-ρ ext2-ρ ext1-∇ ext2-∇ expects-preallocated?) + +(provide flat flat? flat-shape flat-store flat-offset size-of strides) + +;; These will get overriden by duals +(provide tensor?) +(provide rank shape reshape size-of) diff --git a/accelerated-tensors/tensors/0-vectors.rkt b/accelerated-tensors/tensors/0-vectors.rkt new file mode 100644 index 0000000..b8c104f --- /dev/null +++ b/accelerated-tensors/tensors/0-vectors.rkt @@ -0,0 +1,93 @@ +#lang racket +(require ffi/vector) +(require ffi/unsafe) + +;;------------------------------------------------ +;; Raw representation of vectors +;;------------------------------------------------ + +(define vec? f32vector?) +(define vec f32vector) +(define make-vec make-f32vector) +(define vref f32vector-ref) +(define vset! f32vector-set!) +(define vlen f32vector-length) +(define list->vec list->f32vector) +(define build-vec + (λ (n proc) + (list->vec (map (compose exact->inexact proc) (range n))))) +(define vec->cpointer f32vector->cpointer) +(define vref-cpointer + (λ (v i) + (unless (< i (vlen v)) + (error 'vref-cpointer + "Index ~a out of range [0, ~a]" + i (vlen v))) + (ptr-add (vec->cpointer v) i _float))) + +(define-for-syntax debug-leaks? #f) +(define-syntax when-debug-leaks + (λ (x) + (syntax-case x () + ((when-debug-leaks expr) + debug-leaks? + #'expr) + ((when-debug-leaks expr) + #'(void))))) + +(define new-vec + (λ (size initial-value [context 'new-vec]) + (let ((m (make-vec size initial-value))) + (when-debug-leaks (manage-flat-vector! m context)) + m))) + +(define vcopy + (λ (dest idest src isrc n) + (for ([id (in-range idest (+ n idest))] + [is (in-range isrc (+ n isrc))]) + (vset! dest id (vref src is))))) + +(provide vec? vec vref vset! vlen vcopy list->vec build-vec vec->cpointer vref-cpointer new-vec) + +;;------------------------------------------------ +;; Memory management for flat-vectors +;;------------------------------------------------ + +(define flat-vector-manager + (make-will-executor)) + +(define manage-flat-vector! + (λ (m context) + (set-count! context (add1 (count context))) + (will-register flat-vector-manager m (flat-vector-collector context)))) + +(define flat-vector-collector + (λ (context) + (λ (v) + (cond + ((vector? v) + (set-count! context (sub1 (count context)))) + (else (fprintf (current-error-port) "?? ...")))))) + +(define start-vector-manager + (λ () + (when-debug-leaks + (void + (thread + (λ () + (let loop () + (will-execute flat-vector-manager) + (loop)))))))) + +(define counts (make-hash)) +(define count (λ (context) (dict-ref counts context 0))) +(define set-count! (λ (context v) (dict-set! counts context v))) +(define vector-manager-report + (λ () + (fprintf (current-error-port) "----------------------------------------------~%") + (fprintf (current-error-port) "context\t\t\tcount~%") + (for ([(context count) (in-hash counts)]) + (fprintf (current-error-port) "~a\t\t\t~a~%" context count)) + (fprintf (current-error-port) "----------------------------------------------~%"))) + +(provide start-vector-manager vector-manager-report) diff --git a/accelerated-tensors/tensors/1-flats.rkt b/accelerated-tensors/tensors/1-flats.rkt new file mode 100644 index 0000000..23aeb79 --- /dev/null +++ b/accelerated-tensors/tensors/1-flats.rkt @@ -0,0 +1,73 @@ +#lang racket + +;-------------------------------------------------------- +; Representation of tensors +;-------------------------------------------------------- + +;; A flat tensor representation is for a contiguous slice in the backing store. +;; The fields we need: +;; shape : list +;; store : vector +;; offset : start of the contiguous slice. +;; size : number of elements in the contiguous slice +;; strides : Number of elements in each dimension of the tensor. +;; rank: Number of dimensions in the tensor + + + +(define flat + (λ (shape store offset) + (vector flat shape store offset + (size-of shape) + (strides shape) + (length shape)))) + +(define flat? + (λ (v) + (and (vector? v) + (eq? (vector-ref v 0) flat)))) + +(define flat-shape + (λ (f) + (vector-ref f 1))) + +(define flat-store + (λ (f) + (vector-ref f 2))) + +(define flat-offset + (λ (f) + (vector-ref f 3))) + +(define flat-size + (λ (f) + (vector-ref f 4))) + +(define flat-strides + (λ (f) + (vector-ref f 5))) + +(define flat-rank + (λ (f) + (vector-ref f 6))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(provide flat flat? flat-shape flat-store + flat-offset flat-rank flat-strides flat-size + size-of strides) diff --git a/accelerated-tensors/tensors/A-equality.rkt b/accelerated-tensors/tensors/A-equality.rkt new file mode 100644 index 0000000..639533f --- /dev/null +++ b/accelerated-tensors/tensors/A-equality.rkt @@ -0,0 +1,66 @@ +#lang racket + +;;—————————————————–—————————————————–—————————————————– +;; Equality checks for mostly for testing. +;;—————————————————–—————————————————–—————————————————– + +(require "0-vectors.ss") +(require "1-flats.ss") +(require rackunit) + +;;—————————————————–—————————————————–—————————————————– +;; These parameters can be overriden to account for +;; different type of numbers used inside tensors. +;;—————————————————–—————————————————–—————————————————– + +(define tolerance (make-parameter 0.0001)) + +(define equal-within-tolerance? + (make-parameter + (λ (actual expected) + (< (abs (- actual expected)) (tolerance))))) + +;;—————————————————–—————————————————–—————————————————– +;; These are representation specific, but part of the +;; exported interface of the module +;;—————————————————–—————————————————–—————————————————– + +(define tensor-equal? + (λ (actual expected) + (or (equal? actual expected) + (and (real? actual) + (real? expected) + ((equal-within-tolerance?) actual expected)) + (and (flat? actual) + (flat? expected) + (equal? (flat-shape actual) + (flat-shape expected)) + (equal-elements? actual expected))))) + +(define (equal-elements? actual expected) + (let ((actual-offset (flat-offset actual)) + (expected-offset (flat-offset expected)) + (actual-size (flat-size actual)) + (expected-size (flat-size expected)) + (actual-store (flat-store actual)) + (expected-store (flat-store expected))) + (and (equal? actual-size expected-size) + (call/cc (λ (return) + (for/fold ([check #t]) + ([i-actual (in-range actual-offset + (+ actual-offset + actual-size))] + [i-expected (in-range expected-offset + (+ expected-offset + expected-size))]) + (cond + (((equal-within-tolerance?) + (vref actual-store i-actual) + (vref expected-store i-expected)) check) + (else (return #f))))))))) + +(define-binary-check (check-tensor-equal? tensor-equal? actual expected)) + +(include "test/test-A-equality.rkt") + +(provide tolerance equal-within-tolerance? tensor-equal? check-tensor-equal? equal-elements?) diff --git a/accelerated-tensors/tensors/B-tensor-basics.rkt b/accelerated-tensors/tensors/B-tensor-basics.rkt new file mode 100644 index 0000000..63e9a71 --- /dev/null +++ b/accelerated-tensors/tensors/B-tensor-basics.rkt @@ -0,0 +1,184 @@ +#lang racket + +;-------------------------------------------------------- +; Memory management tools for vectors +;-------------------------------------------------------- + +(require "0-vectors.ss") +(require "1-flats.ss") + +;-------------------------------------------------------- +; Lists +;-------------------------------------------------------- +(define ref list-ref) +(define refr drop) +(define len length) + +(provide ref refr len) + +;-------------------------------------------------------- +; Tensor basics +;-------------------------------------------------------- + +(define tref + (λ (t i) + (cond + ((= 1 (flat-rank t)) + (vref (flat-store t) (+ (flat-offset t) i))) + (else + (flat (cdr (flat-shape t)) + (flat-store t) + (+ (flat-offset t) (* i (car (flat-strides t))))))))) + +(define tlen + (λ (t) + (car (flat-shape t)))) + +(define flat-ref-idx + (λ (v indices) + (flat-ref-idx* (flat-offset v) (flat-strides v) indices))) + +(define flat-ref-idx* + (λ (current-idx strides indices) + (cond + ((null? indices) current-idx) + (else + (flat-ref-idx* + (+ current-idx + (* (car indices) (car strides))) + (cdr strides) + (cdr indices)))))) + +(define strides + (λ (shape) + (cond + ((null? shape) '()) + (else (cons (size-of (cdr shape)) + (strides (cdr shape))))))) + +(define size-of + (λ (shape) + (product shape 1))) + +(define product + (λ (lst a) + (cond + ((null? lst) a) + (else (product (cdr lst) (* (car lst) a)))))) + +(define list->tensor + (λ (lst) + (cond + ((null? lst) (error 'list->flat-tensor "No elements found")) + ((number? (car lst)) + (flat (list (length lst)) (list->vec lst) 0)) + (else + (flat-tensor-from-list lst))))) + +(define flat-tensor-from-list + (λ (lst) + (let* ([inner-shape (flat-shape (car lst))] + [inner-size (size-of inner-shape)] + [outer-shape (cons (length lst) inner-shape)] + [size (size-of outer-shape)] + [v (new-vec size 0.0 'from-list)]) + (for ([fl lst] + [i (in-naturals 0)]) + (vcopy v (* i inner-size) + (flat-store fl) (flat-offset fl) + inner-size)) + (flat outer-shape v 0)))) + +(define tensor? + (λ (t) + (or (number? t) + (flat? t)))) + +(define tensor + (λ args + (ensure-shape args) + (cond + ((number? (car args)) (flat (list (length args)) + (list->vec (map exact->inexact args)) + 0)) + (else (merge-flats args))))) + +(define merge-flats + (λ (args) + (let* ((inner-shape (flat-shape (car args))) + (outer (length args)) + + (new-shape (cons outer inner-shape)) + (stride (size-of inner-shape)) + + (new-size (size-of new-shape)) + + (v-out (new-vec new-size +nan.0 'tensor))) + (for ([i-out (in-range outer)] + [arg args]) + (vcopy v-out (* i-out stride) (flat-store arg) (flat-offset arg) stride)) + (flat new-shape v-out 0)))) + +(define ensure-shape + (λ (args) + (when (null? args) + (error 'tensor "Tensors cannot be empty")) + (let ((checked-shape + (λ (x) (if (flat? x) + (flat-shape x) + '())))) + (unless (and (not (null? args)) + (cond + ((number? (car args)) + (andmap number? (cdr args))) + ((flat? (car args)) + (let ((s (checked-shape (car args)))) + (andmap (λ (t) + (and (flat? t) + (equal? (checked-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Cannot construct a tensor out of these elements: ~a~%" + args))))) + +(define build-tensor + (λ (shape f) + (let* ((size (size-of shape)) + (v (new-vec size 0.0 'build-tensor)) + (strides (strides shape))) + (fill-flat-tensor v shape strides f 0 '()) + (flat shape v 0)))) + +(define fill-flat-tensor + (λ (dest shape strides f offset tidx) + (cond + ((null? (cdr shape)) + (for ([i (in-range 0 (car shape))]) + (vset! dest (+ offset i) + (exact->inexact (f (append tidx (list i))))))) + (else + (let ((stride (car strides))) + (for ([i (in-range 0 (car shape))]) + (fill-flat-tensor dest + (cdr shape) (cdr strides) f + (+ offset (* i stride)) (append tidx (list i))))))))) + +(define trefs + (λ (t b) + (let* ([st (flat-shape t)] + [est (cdr st)] + [estride (size-of est)] + [nshape (cons (length b) (cdr st))] + [size-out (size-of nshape)] + [v-out (new-vec size-out 0.0 'flat-refs)] + [vt (flat-store t)]) + (for ([ib b] + [i-out (in-range 0 size-out estride)]) + (vcopy v-out i-out vt (* ib estride) estride)) + (flat nshape v-out 0)))) + +(include "test/test-B-tensor-basics.rkt") + +(provide tref tlen list->tensor number? + tensor? tensor build-tensor trefs merge-flats) diff --git a/accelerated-tensors/tensors/C-tensor-ops.rkt b/accelerated-tensors/tensors/C-tensor-ops.rkt new file mode 100644 index 0000000..f9056b4 --- /dev/null +++ b/accelerated-tensors/tensors/C-tensor-ops.rkt @@ -0,0 +1,35 @@ +#lang racket + +(require "1-flats.ss") +(require "B-tensor-basics.ss") + +;;—————————————————– +;; Shape, rank, size-of +;;—————————————————– + +(define shape + (λ (t) + (cond + ((number? t) '()) + (else (flat-shape t))))) + +(define rank + (λ (t) + (len (shape t)))) + +;;—————————————————– +;; Reshape a tensor +;;—————————————————– + +(define reshape + (λ (s t) + (cond + ((= (size-of s) (flat-size t)) + (flat s (flat-store t) (flat-offset t))) + (else (error "Cannot reshape ~a to ~a~%" (flat-shape t) s))))) + + +(include "test/test-C-tensor-ops.rkt") + +(provide rank shape reshape) +(provide size-of strides) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt new file mode 100644 index 0000000..bd45305 --- /dev/null +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -0,0 +1,393 @@ +#lang racket + +(require "0-vectors.ss") +(require "1-flats.ss") +(require "B-tensor-basics.ss") +(require "C-tensor-ops.ss") + +;;—————————————————–—————————————————–—————————————————– +;; Unary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext1-ρ + (λ (f m [shape-fn scalar-shape]) + (λ (t) + (cond + ((number? t) (f t)) + ((expects-preallocated? f) + (scalarize + (flat-ext1-ρ f m shape-fn t))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f m shape-fn t)))))))) + +(define ext1-∇ + (λ (f m [shape-fn scalar-shape]) + (λ (t z) + (cond + ((number? t) (f t z)) + ((expects-preallocated? f) + (scalarize (flat-ext1-∇ f m shape-fn t (ensure-flat z)))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f m shape-fn t (ensure-flat z))))))))) + +(define functional->preallocated-1-ρ + (λ (f base-shape out-shape) + (λ (v0 i0 stride0 v-out i-out stride-out) + (set-prealloc-ρ! v-out i-out out-shape + (f (arg-value base-shape v0 i0)))))) + +(define functional->preallocated-1-∇ + (λ (f base-shape out-shape) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value base-shape v0 i0))) + (set-prealloc-∇! g0 i0 base-shape (f a z)))))) + +(define set-prealloc-ρ! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out a)) + (else (v-copy-flat! v-out i-out a))))) + +(define set-prealloc-∇! + (λ (v-out i-out out-shape a) + (cond + ((null? out-shape) (vset! v-out i-out (+ (vref v-out i-out) a))) + (else (v-add-flat! v-out i-out a))))) + +(define arg-value + (λ (v-shape v i) + (cond + ((null? v-shape) (vref v i)) + (else (flat v-shape v i))))) + + +(define invoke-functional-∇ + (λ (f base-shape v0 i0) + (cond + ((null? base-shape) (f (vref v0 i0))) + (else (f (flat base-shape v0 i0)))))) + +;;—————————————————–—————————————————–—————————————————– +;; Binary Pointwise extension +;;—————————————————–—————————————————–—————————————————– + +(define ext2-ρ + (λ (f m n [shape-fn scalar-shape]) + (λ (t u) + (cond + ((and (number? t) (number? u)) (f t u)) + ((expects-preallocated? f) + (scalarize + (flat-ext2-ρ f m n shape-fn t u))) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn (ensure-flat t) u)))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t (ensure-flat u))))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t u)))))))) + +(define ext2-∇ + (λ (f m n [shape-fn scalar-shape]) + (λ (t u z) + (let ((invoke-flat-ext2-∇ + (λ (f m n shape-fn t u z) + (let-values (((da db) (flat-ext2-∇ f m n shape-fn t u z))) + (values (scalarize da) (scalarize db)))))) + (cond + ((and (number? t) (number? u)) (f t u z)) + ((expects-preallocated? f) + (invoke-flat-ext2-∇ f m n shape-fn t u z)) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn (ensure-flat t) u z))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn t (ensure-flat u) z))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f m n shape-fn t u z)))))))) + +(define functional->preallocated-2-ρ + (λ (f t-shape u-shape out-shape) + (λ (v0 i0 stride0 v1 i1 stride1 v-out i-out stride-out) + (set-prealloc-ρ! v-out i-out out-shape + (f (arg-value t-shape v0 i0) + (arg-value u-shape v1 i1)))))) + +(define functional->preallocated-2-∇ + (λ (f t-shape u-shape out-shape) + (λ (g0 g1 v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (let ((z (arg-value out-shape vz iz)) + (a (arg-value t-shape v0 i0)) + (b (arg-value u-shape v1 i1))) + (let-values (((da db) (f a b z))) + (set-prealloc-∇! g0 i0 t-shape da) + (set-prealloc-∇! g1 i1 u-shape db)))))) + +(define idxs + (λ (strides out-i i0 i1) + (for/fold ([i0 i0] + [i1 i1] + [x out-i] #:result (values i0 i1)) + ([stride strides]) + (let ((idx (quotient x (vector-ref stride 0))) + (next-x (remainder x (vector-ref stride 0)))) + (values (+ i0 (* idx (vector-ref stride 1))) + (+ i1 (* idx (vector-ref stride 2))) + next-x))))) + +(define merge-shapes + (λ (in-shape min-rank out-f-shape) + (append (take in-shape (- (length in-shape) min-rank)) + out-f-shape))) + +(define flat-ext1-ρ + (λ (f min-rank shape-fn t0) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sf-out (shape-fn sf0)) + (stride-out (size-of sf-out)) + (s-out (merge-shapes s0 min-rank sf-out)) + (size-out (size-of s-out)) + (v-out (new-vec size-out 0.0))) + (for ([i-out (in-range 0 size-out stride-out)] + #;[i0 (in-range off0 (+ off0 size0) stride0)]) + (define i0 (+ off0 (* (/ i-out stride-out) stride0))) + (f v0 i0 stride0 v-out i-out stride-out)) + (flat s-out v-out 0)))) + +(define flat-ext1-∇ + (λ (fᵈ min-rank shape-fn t0 z) + ;; z has the same shape as the output + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape min-rank s0)) + (stride0 (size-of sf0)) + (size0 (size-of s0)) + + (sz (flat-shape z)) + (size-z (size-of sz)) + (sf-z (shape-fn sf0)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + + (g0 (new-vec size0 0.0))) + (for ([iz (in-range 0 size-z stride-z)] + #;[i0 (in-range off0 (+ off0 size0) stride0)]) + (define i0 (+ off0 (* (/ iz stride-z) stride0))) + (fᵈ g0 v0 i0 stride0 vz iz stride-z)) + (flat s0 g0 0)))) + +(define flat-ext2-ρ + (λ (f r0 r1 shape-fn t0 t1) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + + (sf-out (shape-fn sf0 sf1)) + (stride0 (size-of sf0)) + (stride1 (size-of sf1)) + (stride-out (size-of sf-out))) + (ext2-shapes s0 s1 r0 r1 sf-out + (λ (s-out size-out q0 q1 strides) + (let ((out-v (new-vec size-out 0.0))) + (for ([out-i (in-range 0 size-out stride-out)]) + (let-values (((i0 i1) + (idxs strides out-i off0 off1))) + (f v0 i0 stride0 v1 i1 stride1 out-v (+ 0 out-i) stride-out))) + (flat s-out out-v 0))))))) + +(define flat-ext2-∇ + (λ (fᵈ r0 r1 shape-fn t0 t1 z) + (let* ((s0 (flat-shape t0)) + (v0 (flat-store t0)) + (off0 (flat-offset t0)) + (sf0 (min-shape r0 s0)) + (stride0 (size-of sf0)) + + (s1 (flat-shape t1)) + (v1 (flat-store t1)) + (off1 (flat-offset t1)) + (sf1 (min-shape r1 s1)) + (stride1 (size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (size-of sf-z)) + (vz (flat-store z)) + (offz (flat-offset z))) + (ext2-shapes s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (size-of s0) 0.0)) + (g1 (new-vec (size-of s1) 0.0))) + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) + (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))) + (values (flat s0 g0 0) + (flat s1 g1 0)))))))) + +(define ext2-shapes + (λ (s0 s1 r0 r1 sf-out k) + (let ((l0 (length s0)) + (l1 (length s1))) + (cond + ((and (= r0 l0) (= r1 l1)) + (k sf-out + (size-of sf-out) + (size-of s0) + (size-of s1) + '())) + + ((= r0 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((= r1 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + ((and (not (null? s0)) + (not (null? s1)) + (= (car s0) (car s1))) + (ext2-shapes (cdr s0) (cdr s1) r0 r1 sf-out + (desc-both (car s0) k))) + + ((> l1 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((> l0 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + (else (error 'ext + "Shapes are incompatible for ext2: ~a, and ~a for min ranks ~a, and ~a~%" + s0 s1 r0 r1)))))) + +(define desc-both + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + (* q1 d) + (cons (vector qout q0 q1) strides))))) + +(define desc-left + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + (* q0 d) + q1 + (cons (vector qout q0 0) strides))))) + +(define desc-right + (λ (d k) + (λ (s-out qout q0 q1 strides) + (k (cons d s-out) + (* qout d) + q0 + (* q1 d) + (cons (vector qout 0 q1) strides))))) + +(define v-copy-flat! + (λ (vg ig a) + ;; copy elements from a to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (vref va (+ a-offset i))))))) + +(define v-add-flat! + (λ (vg ig a) + ;; copy elements to a to vg while adding them to vg + (let ((va (flat-store a)) + (a-offset (flat-offset a)) + (a-stride (size-of (flat-shape a)))) + (for ([i (in-range 0 a-stride)]) + (vset! vg (+ ig i) + (+ (vref vg (+ ig i)) + (vref va (+ a-offset i)))))))) + +(define expects-preallocated? + (λ (f) + (let ((a (procedure-arity f))) + (and (integer? a) + (>= a 6))))) + +(define ensure-flat + (λ (z) + (cond + ((number? z) + (flat '() (new-vec 1 (exact->inexact z)) 0)) + (else z)))) + +(define scalarize + (λ (t) + (cond + ((null? (flat-shape t)) (vref (flat-store t) 0)) + (else t)))) + +(define min-shape + (λ (min-rank in-shape) + (drop in-shape (- (length in-shape) min-rank)))) + +(define scalar-shape + (λ (s0 [s1 '()]) '())) + +(include "test/test-D-extend.rkt") + +(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? + functional->preallocated-1-ρ functional->preallocated-1-∇ + functional->preallocated-2-ρ functional->preallocated-2-∇ + merge-shapes min-shape ext2-shapes idxs + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/accelerated-tensors/tensors/test/test-A-equality.rkt b/accelerated-tensors/tensors/test/test-A-equality.rkt new file mode 100644 index 0000000..9488d8c --- /dev/null +++ b/accelerated-tensors/tensors/test/test-A-equality.rkt @@ -0,0 +1,84 @@ +(module+ test + (require rackunit) + + (check-true ((equal-within-tolerance?) 1.00001 1.0001)) + (check-true ((equal-within-tolerance?) 1.0002 1.0001)) + (check-false ((equal-within-tolerance?) 1.0003 1.0001)) + + (define t0 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.0 i))) + 0)) + + + (define t1 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t2 + (flat '(1 2 3 4) + (build-vec 24 + (λ (i) + (* 2.000001 i))) + 0)) + + (define t3 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (* (quotient i 24) i))) + 0)) + + (define t4 + (flat '(2 2 3 4) + (build-vec 48 + (λ (i) + (- (* 2.000001 (* (quotient i 24) i)) 48.0))) + 0)) + + (check-true (equal-elements? t0 t1)) + + (check-true (equal-elements? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (equal-elements? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (equal-elements? t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (equal-elements? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-true (tensor-equal? t0 t1)) + + (check-false (tensor-equal? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (tensor-equal? t0 (flat '(2 3 4) + (flat-store t2) + 0))) + + (check-false (tensor-equal? t1 (flat '(2 3 4) + (flat-store t3) + 24))) + + (check-true (tensor-equal? t1 (flat '(2 3 4) + (flat-store t4) + 24))) + + (check-tensor-equal? t0 t1) + + (check-tensor-equal? t0 (flat '(2 3 4) + (flat-store t2) + 0)) + + (check-tensor-equal? t1 (flat '(2 3 4) + (flat-store t4) + 24))) diff --git a/accelerated-tensors/tensors/test/test-B-tensor-basics.rkt b/accelerated-tensors/tensors/test/test-B-tensor-basics.rkt new file mode 100644 index 0000000..903220b --- /dev/null +++ b/accelerated-tensors/tensors/test/test-B-tensor-basics.rkt @@ -0,0 +1,33 @@ +(module+ test + (require rackunit) + (require "A-equality.ss") + + (define r0-td 3.0) + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0))) + (define r3-td + (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5)) + (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11)) + (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17)) + (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23)))) + + (check-tensor-equal? (tref r1-td 2) 5.0) + (check-equal? (tlen r1-td) 3) + (check-tensor-equal? (list->tensor (list 3.0 4.0 5.0)) r1-td) + + (check-true (and (tensor? r0-td) (tensor? r1-td))) + (check-false (tensor? '(a b c))) + + (check-tensor-equal? (build-tensor '(4 3 2) + (λ (idx) + (+ (* 6 (ref idx 0)) + (* 2 (ref idx 1)) + (ref idx 2)))) + r3-td) + + (check-tensor-equal? (build-tensor '(1 2 3) (λ (idx) (+ (list-ref idx 0) (list-ref idx 1) (list-ref idx 2)))) + (tensor (tensor (tensor 0 1 2) (tensor 1 2 3)))) + + (check-tensor-equal? (trefs r1-td '(0 2)) (tensor 3.0 5.0)) + + ) diff --git a/accelerated-tensors/tensors/test/test-C-tensor-ops.rkt b/accelerated-tensors/tensors/test/test-C-tensor-ops.rkt new file mode 100644 index 0000000..4509a31 --- /dev/null +++ b/accelerated-tensors/tensors/test/test-C-tensor-ops.rkt @@ -0,0 +1,63 @@ +(module+ test + (require rackunit) + (require "A-equality.ss") + + (define r0-td 3.0) + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tensor (tensor 3.0 4.0 5.0) (tensor 7.0 8.0 9.0))) + (define r3-td + (tensor (tensor (tensor 0 1) (tensor 2 3) (tensor 4 5)) + (tensor (tensor 6 7) (tensor 8 9) (tensor 10 11)) + (tensor (tensor 12 13) (tensor 14 15) (tensor 16 17)) + (tensor (tensor 18 19) (tensor 20 21) (tensor 22 23)))) + + (define test-shape (list 2 2 3)) + + (check-equal? (shape r0-td) (list)) + (check-equal? (shape r1-td) (list 3)) + (check-equal? (shape r2-td) (list 2 3)) + + (check-equal? (rank r0-td) 0) + (check-equal? (rank r1-td) 1) + (check-equal? (rank r2-td) 2) + + (check-equal? (size-of '()) 1) + (check-equal? (size-of test-shape) 12) + + + (check-equal? (size-of '(4 3 2)) 24) + + (check-tensor-equal? (reshape '(24) r3-td) + (tensor 0 1 2 3 4 5 + 6 7 8 9 10 11 + 12 13 14 15 16 17 + 18 19 20 21 22 23)) + + (check-tensor-equal? (reshape '(4 1) (tensor 0 1 2 3)) + (tensor (tensor 0) (tensor 1) (tensor 2) (tensor 3))) + + (check-tensor-equal? (reshape '(6) r2-td) + (tensor 3.0 4.0 5.0 7.0 8.0 9.0)) + + (check-tensor-equal? (reshape '(3 2) r2-td) + (tensor (tensor 3.0 4.0) + (tensor 5.0 7.0) + (tensor 8.0 9.0))) + + + (check-exn exn:fail? + (λ () + (tensor "1 2" 1 2))) + + (check-exn exn:fail? + (λ () + (tensor))) + + (check-exn exn:fail? + (λ () + (tensor 1 (tensor 2 3)))) + + (check-exn exn:fail? + (λ () + (tensor tensor (tensor 2 3)))) +) diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt new file mode 100644 index 0000000..5b2c8d0 --- /dev/null +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -0,0 +1,236 @@ +(module+ test + (require rackunit) + (require "A-equality.rkt") + (require "B-tensor-basics.rkt") + + (define sum-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (vset! out-v iₒ + (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) + (+ sum (vref in-v i)))))) + + (define sum-shape-f + (λ (in-f-shape) + '())) + + (define sum (ext1-ρ sum-f 1 sum-shape-f)) + + (check-equal? (min-shape 2 '(3 4 5 6)) '(5 6)) + + (check-equal? (min-shape 0 '(3 4 5 6)) '()) + + (check-equal? (merge-shapes '(3 4 5 6) 1 '()) + '(3 4 5)) + + (define t0 + (flat '(2 3 4) + (build-vec 24 + (λ (i) + (* 2 i))) + 0)) + + (check-true (equal-elements? (sum t0) + (tensor 12.0 44.0 76.0 108.0 140.0 172.0))) + + + (define dup-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (for ([i (in-range 0 sₒ)]) + (vset! out-v (+ iₒ i) + (vref in-v (+ iᵢ (modulo i sᵢ))))))) + + (define dup-shape-f + (λ (in-f-shape) + (list (* 2 (car in-f-shape))))) + + (define dup (ext1-ρ dup-f 1 dup-shape-f)) + (check-true (equal-elements? (dup t0) + (tensor 0 2 4 6 0 2 4 6 + 8 10 12 14 8 10 12 14 + 16 18 20 22 16 18 20 22 + 24 26 28 30 24 26 28 30 + 32 34 36 38 32 34 36 38 + 40 42 44 46 40 42 44 46))) + + (define s0 '(3 4 5 6)) + (define s1 '(3 7 6)) + (define r0 2) + (define r1 1) + + (ext2-shapes s0 s1 r0 r1 '(5 6) + (λ (s-out size-out q0 q1 strides) + (check-equal? s-out '(3 4 7 5 6)) + (check-equal? size-out 2520) + (check-equal? strides '(#(840 120 42) #(210 30 0) #(30 0 6))) + (let-values (((i0 i1) (idxs strides 0 0 0))) + (check-equal? i0 0) + (check-equal? i1 0)) + + (let-values (((i0 i1) (idxs strides 30 0 0))) + (check-equal? i0 0) + (check-equal? i1 6)) + + (let-values (((i0 i1) (idxs strides 210 0 0))) + (check-equal? i0 30) + (check-equal? i1 0)) + + (let-values (((i0 i1) (idxs strides 240 0 0))) + (check-equal? i0 30) + (check-equal? i1 6)) + + (let-values (((i0 i1) (idxs strides 420 0 0))) + (check-equal? i0 60) + (check-equal? i1 0)) + + (let-values (((i0 i1) (idxs strides 840 0 0))) + (check-equal? i0 120) + (check-equal? i1 42)) + )) + + + (define *-ρ (ext2-ρ * 0 0)) + (define t0sqr (*-ρ t0 t0)) + + (check-true (equal-elements? + t0sqr + (tensor 0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116))) + + (define *-2-1-f + (λ (v0 i0 s0 v1 i1 s1 vout iout sout) + (for ([j0 (in-range 0 s0)]) + (vset! vout (+ iout j0) + (* (vref v0 (+ i0 j0)) + (vref v1 (+ i1 (modulo j0 s1)))))))) + + (define t1 + (flat '(5 6) + (build-vec 30 + (λ (i) (* 2.0 i))) + 0)) + + (define t2 + (flat '(6) + (build-vec 6 + (λ (i) (* 3.0 i))) + 0)) + + (define *-2-1 + (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) + + (define r-1-2 + (*-2-1 t1 t2)) + + (check-tensor-equal? r-1-2 + (reshape + '(5 6) + (tensor 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0))) + + (define t3 + (flat '(3 5 6) + (build-vec 90 + (λ (i) (* 2.0 i))) + 0)) + + (define t4 + (flat '(3 6) + (build-vec 18 + (λ (i) (* 3.0 i))) + 0)) + + (define r-3-4 + (*-2-1 t3 t4)) + + (check-tensor-equal? r-3-4 + (reshape + '(3 5 6) + (tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0)))) + +(module+ test + (require rackunit) + + (define r0-td 3.0) + (define r1-td (flat '(3) (list->vec '(3.0 4.0 5.0)) 0)) + (define r2-td (flat '(2 3) (list->vec '(3.0 4.0 5.0 7.0 8.0 9.0)) 0)) + + (define +ᶠ +) + (define +ᵈ (λ (a b z) (values z z))) + + (define sqrᶠ (λ (a) (* a a))) + (define sqrᵈ + (λ (a z) (* z 2 a))) + + (define d-sqr (ext1-∇ sqrᵈ 0 scalar-shape)) + + (define one-like + (λ (t) + (let* ((st (flat-shape t)) + (size-t (size-of st))) + (flat st + (new-vec size-t 1.0) + 0)))) + + (check-true (equal-elements? (d-sqr r1-td (one-like r1-td)) (tensor 6.0 8.0 10.0))) + + (let ((gsqr (d-sqr r2-td (one-like r2-td)))) + (check-tensor-equal? gsqr (reshape '(2 3) (tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + + (define d+ (ext2-∇ +ᵈ 0 0 scalar-shape)) + + (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) + (check-tensor-equal? da (tensor 1.0 1.0 1.0)) + (check-tensor-equal? db (tensor 1.0 1.0 1.0))) + + (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) + (check-tensor-equal? da (tensor 2.0 2.0 2.0)) + (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + + (define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) + 0 + 0)) + + (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0)))) + (check-tensor-equal? gt (tensor 1.0 2.0 3.0)) + (check-tensor-equal? gu (tensor 2.0 3.0 4.0))) + + (define sum-1-∇ + (λ (g t it st vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g i (vref vz iz))))) + + (define sum-∇ (ext1-∇ sum-1-∇ 1 (λ (s) '()))) + + (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) + 1.0))) + (check-tensor-equal? gt (tensor 1.0 1.0 1.0))) + + (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) + (tensor 2.0 3.0 4.0)) + (tensor 2.0 1.0)))) + (check-tensor-equal? gt (tensor (tensor 2.0 2.0 2.0) + (tensor 1.0 1.0 1.0))))) diff --git a/impl-no-duals-no-overrides.rkt b/impl-no-duals-no-overrides.rkt index 2fe6658..8073099 100644 --- a/impl-no-duals-no-overrides.rkt +++ b/impl-no-duals-no-overrides.rkt @@ -13,10 +13,14 @@ #,(case (tensor-implementation) ((learner) #'(require "learner/no-duals-no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals-no-overrides.rkt")) + ((uniform-tensors) #'(require "uniform-tensors/no-duals-no-overrides.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors/no-duals-no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals-no-overrides.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner/no-duals-no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals-no-overrides.rkt"))) + ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals-no-overrides.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors/no-duals-no-overrides.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-duals-no-overrides.rkt"))))))) (load-tensors) diff --git a/impl-no-duals.rkt b/impl-no-duals.rkt index 5384fca..ba72d04 100644 --- a/impl-no-duals.rkt +++ b/impl-no-duals.rkt @@ -14,11 +14,13 @@ ((learner) #'(require "learner/no-duals.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals.rkt")) ((uniform-tensors) #'(require "uniform-tensors/no-duals.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors/no-duals.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner/no-duals.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors/no-duals.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-duals.rkt"))))))) (load-tensors) diff --git a/impl-no-overrides.rkt b/impl-no-overrides.rkt index 56a96c7..b1e6669 100644 --- a/impl-no-overrides.rkt +++ b/impl-no-overrides.rkt @@ -14,11 +14,13 @@ ((learner) #'(require "learner/no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-overrides.rkt")) ((uniform-tensors) #'(require "uniform-tensors/no-overrides.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors/no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-overrides.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner/no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-overrides.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-overrides.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors/no-overrides.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors/no-overrides.rkt"))))))) (load-tensors) diff --git a/impl.rkt b/impl.rkt index 4508b1f..3e47f85 100644 --- a/impl.rkt +++ b/impl.rkt @@ -14,11 +14,13 @@ ((learner) #'(require "learner.rkt")) ((flat-tensors) #'(require "flat-tensors.rkt")) ((uniform-tensors) #'(require "uniform-tensors.rkt")) + ((accelerated-tensors) #'(require "accelerated-tensors.rkt")) ((nested-tensors) #'(require "nested-tensors.rkt"))) #,(case (tensor-implementation) ((learner) #'(provide (all-from-out "learner.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors.rkt"))) + ((accelerated-tensors) #'(provide (all-from-out "accelerated-tensors.rkt"))) ((nested-tensors) #'(provide (all-from-out "nested-tensors.rkt"))))))) (load-tensors) diff --git a/set-impl.rkt b/set-impl.rkt index 712c23d..ded4d92 100644 --- a/set-impl.rkt +++ b/set-impl.rkt @@ -10,7 +10,8 @@ (when (not (member impl '(learner nested-tensors flat-tensors - uniform-tensors))) + uniform-tensors + accelerated-tensors))) (error "Unknown implementation: ~a~%" impl)) (setup #:collections (list (list "malt")) #:clean? #t) (write-implementation-to-config-file impl) From fc7857cbfb891cdb8c413d27c3e6ec8ff096811b Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 18:44:13 -0400 Subject: [PATCH 35/83] =?UTF-8?q?[add-acc]Use=20z=20offset=20in=20flat-ext?= =?UTF-8?q?1-=E2=88=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- uniform-tensors/tensors/D-extend.rkt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/uniform-tensors/tensors/D-extend.rkt b/uniform-tensors/tensors/D-extend.rkt index bd45305..7ff2046 100644 --- a/uniform-tensors/tensors/D-extend.rkt +++ b/uniform-tensors/tensors/D-extend.rkt @@ -210,12 +210,13 @@ (sf-z (shape-fn sf0)) (stride-z (size-of sf-z)) (vz (flat-store z)) + (offz (flat-offset z)) (g0 (new-vec size0 0.0))) (for ([iz (in-range 0 size-z stride-z)] #;[i0 (in-range off0 (+ off0 size0) stride0)]) (define i0 (+ off0 (* (/ iz stride-z) stride0))) - (fᵈ g0 v0 i0 stride0 vz iz stride-z)) + (fᵈ g0 v0 i0 stride0 vz (+ offz iz) stride-z)) (flat s0 g0 0)))) (define flat-ext2-ρ From 3a3a50f673ecfc9dab267bedf1be14e5a87a6661 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 19:30:12 -0400 Subject: [PATCH 36/83] [add-acc]Minor fixes and a bugged accelerated runtime --- Makefile | 1 + accelerated-tensors/tensors/0-vectors.rkt | 10 +- accelerated-tensors/tensors/2-acc-runtime.rkt | 481 ++++++++++++++++++ accelerated-tensors/tensors/D-extend.rkt | 157 ++++-- .../tensors/test/test-D-extend.rkt | 79 ++- impl-loader.rkt | 9 +- 6 files changed, 675 insertions(+), 62 deletions(-) create mode 100644 accelerated-tensors/tensors/2-acc-runtime.rkt diff --git a/Makefile b/Makefile index ea57e90..b5b0bfb 100644 --- a/Makefile +++ b/Makefile @@ -167,6 +167,7 @@ ACCELERATED_EXT_OPS_DIR=$(ACCELERATED_DIR)/ext-ops ACCELERATED_TENSORS_SOURCES=\ $(ACCELERATED_TENSORS_DIR)/0-vectors.rkt\ $(ACCELERATED_TENSORS_DIR)/1-flats.rkt\ + $(ACCELERATED_TENSORS_DIR)/2-acc-runtime.rkt\ $(ACCELERATED_TENSORS_DIR)/A-equality.rkt\ $(ACCELERATED_TENSORS_DIR)/B-tensor-basics.rkt\ $(ACCELERATED_TENSORS_DIR)/C-tensor-ops.rkt\ diff --git a/accelerated-tensors/tensors/0-vectors.rkt b/accelerated-tensors/tensors/0-vectors.rkt index b8c104f..8ef4d37 100644 --- a/accelerated-tensors/tensors/0-vectors.rkt +++ b/accelerated-tensors/tensors/0-vectors.rkt @@ -47,7 +47,15 @@ [is (in-range isrc (+ n isrc))]) (vset! dest id (vref src is))))) -(provide vec? vec vref vset! vlen vcopy list->vec build-vec vec->cpointer vref-cpointer new-vec) +(define print-vec + (λ (v (off 0) (port (current-output-port))) + (fprintf port "#(") + (for ((i (in-range off (vlen v)))) + (fprintf port "~a " (vref v i))) + (fprintf port ")~n"))) + +(provide vec? vec vref vset! vlen vcopy print-vec + list->vec build-vec vec->cpointer vref-cpointer new-vec) ;;------------------------------------------------ ;; Memory management for flat-vectors diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt new file mode 100644 index 0000000..8a96911 --- /dev/null +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -0,0 +1,481 @@ +#lang racket + +(require ffi/cvector + ffi/unsafe + opencl/c + string-interpolation + "0-vectors.rkt") + + +(define context (make-parameter #f)) +(define command-queue (make-parameter #f)) + +(define (cvector->vector cv) + (build-vector (cvector-length cv) + (curry cvector-ref cv))) + +(define (with-opencl th) + (let* ([platform (cvector-ref (clGetPlatformIDs:vector) 0)] + [devices (clGetDeviceIDs:vector platform 'CL_DEVICE_TYPE_GPU)] + [device-idx 0] + [device (cvector-ref devices device-idx)]) + (parameterize* ([context #f] + [command-queue #f]) + (dynamic-wind + (λ () + (context (clCreateContext #f (cvector->vector devices))) + (command-queue (clCreateCommandQueue (context) device '()))) + th + (λ () + (when (command-queue) + (clReleaseCommandQueue (command-queue))) + (when (context) + (clReleaseContext (context)))))))) + +(define (binary-expr rator rand1 rand2) + (string-append "(" rand1 " " rator " " rand2 ")")) + +(define idx-exprs-gen + (λ (strides i0 i1) + (λ (out-i) + (for/fold ([i0 (number->string i0)] + [i1 (number->string i1)] + [x out-i] #:result (values i0 i1)) + ([stride strides]) + (let ((stride-out (number->string (vector-ref stride 0))) + (stride0 (number->string (vector-ref stride 1))) + (stride1 (number->string (vector-ref stride 2)))) + (let ((idx (binary-expr "/" x stride-out)) + (next-x (binary-expr "%" x stride-out))) + (values (binary-expr "+" i0 (binary-expr "*" idx stride0)) + (binary-expr "+" i1 (binary-expr "*" idx stride1)) + next-x))))))) + +(define (ext1-ρ-kernel prim1-ρ-f) + #<bytes/utf-8 + (ext1-ρ-kernel prim-kernel-f))))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf-out) + (clSetKernelArg:_cl_int kernel 3 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf0 + (clReleaseMemObject buf0)))))))) + +(define functional->preallocated-1-ρ-acc + (λ (f-acc base-shape out-shape) + (unless (and (null? base-shape) (null? out-shape)) + (error 'ρ1-functional-non-scalar-acc + (string-append "Functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input and output shape found: ~a ~a") + base-shape out-shape)) + (λ (v0 i0 stride0 v-out i-out stride-out) + (let ((a "@{v0}[@{i0}]")) + #<bytes/utf-8 + (ext1-∇-kernel prim-kernel-f))))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf-g) + (clSetKernelArg:_cl_mem kernel 1 buf0) + (clSetKernelArg:_cl_int kernel 2 stride0) + (clSetKernelArg:_cl_mem kernel 3 buf-z) + (clSetKernelArg:_cl_int kernel 4 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-z stride-z)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g + (clReleaseMemObject buf-g)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf0 + (clReleaseMemObject buf0)))))))) + +(define functional->preallocated-1-∇-acc + (λ (f-acc base-shape out-shape) + (unless (and (null? base-shape) (null? out-shape)) + (error '∇1-functional-non-scalar-acc + (string-append "Functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input and output shape found: ~a ~a") + base-shape out-shape)) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z "@{vz}[@{iz}]") + (a "@{v0}[@{i0}]")) + #<bytes/utf-8 + (ext2-ρ-kernel prim-kernel-f + (idx-exprs-gen strides 0 0)))))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf1) + (clSetKernelArg:_cl_int kernel 3 stride1) + (clSetKernelArg:_cl_mem kernel 4 buf-out) + (clSetKernelArg:_cl_int kernel 5 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))))) + +(define functional->preallocated-2-ρ-acc + (λ (f-acc t-shape u-shape out-shape) + (unless (and (null? t-shape) (null? u-shape) (null? out-shape)) + (error 'ρ2-functional-non-scalar-acc + (string-append "Functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input 1, input 2 and output shape found: ~a ~a ~a") + t-shape u-shape out-shape)) + (λ (v0 i0 stride0 v1 i1 stride1 v-out i-out stride-out) + (let ((a "@{v0}[@{i0}]") + (b "@{v1}[@{i1}]")) + #<bytes/utf-8 + (ext2-∇-kernel prim-kernel-f + (idx-exprs-gen strides 0 0)))))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf-g0) + (clSetKernelArg:_cl_mem kernel 1 buf-g1) + (clSetKernelArg:_cl_mem kernel 2 buf0) + (clSetKernelArg:_cl_int kernel 3 stride0) + (clSetKernelArg:_cl_mem kernel 4 buf1) + (clSetKernelArg:_cl_int kernel 5 stride1) + (clSetKernelArg:_cl_mem kernel 6 buf-z) + (clSetKernelArg:_cl_int kernel 7 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-z stride-z)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g1 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size1) + (vec->cpointer g1) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g1 + (clReleaseMemObject buf-g1)) + (when buf-g0 + (clReleaseMemObject buf-g0)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))))) + +(define functional->preallocated-2-∇-acc + (λ (f-acc t-shape u-shape out-shape) + (unless (and (null? t-shape) (null? u-shape) (null? out-shape)) + (error '∇2-functional-non-scalar-acc + (string-append "Functional primitives can only accept and" + " return scalars, so try defining a" + " preallocated primitive instead." + " Input 1, input 2 and output shape found: ~a ~a ~a") + t-shape u-shape out-shape)) + (λ (g0 g1 v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (let ((z "@{vz}[@{iz}]") + (a "@{v0}[@{i0}]") + (b "@{v1}[@{i1}]")) + (let-values (((da db) (f-acc a b z))) + #<preallocated-1-ρ-acc + run-prim1-∇! functional->preallocated-1-∇-acc + run-prim2-ρ! functional->preallocated-2-ρ-acc + run-prim2-∇! functional->preallocated-2-∇-acc) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index bd45305..12fd7b0 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -1,7 +1,9 @@ #lang racket +(require "../../impl-loader.rkt") (require "0-vectors.ss") (require "1-flats.ss") +(require "2-acc-runtime.ss") (require "B-tensor-basics.ss") (require "C-tensor-ops.ss") @@ -9,35 +11,43 @@ ;; Unary Pointwise extension ;;—————————————————–—————————————————–—————————————————– +;; TODO: Replace accelerate? parameter with a function which determines based on +;; the size of computation workload when to disable GPU acceleration and default +;; to running code on the CPU . + (define ext1-ρ - (λ (f m [shape-fn scalar-shape]) + (λ (f f-acc m [shape-fn scalar-shape]) (λ (t) (cond ((number? t) (f t)) ((expects-preallocated? f) (scalarize - (flat-ext1-ρ f m shape-fn t))) + (flat-ext1-ρ f f-acc m shape-fn t))) (else (let* ((in-shape (flat-shape t)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape out-shape))) (scalarize - (flat-ext1-ρ flat-f m shape-fn t)))))))) + (flat-ext1-ρ flat-f flat-f-acc m shape-fn t)))))))) (define ext1-∇ - (λ (f m [shape-fn scalar-shape]) + (λ (f f-acc m [shape-fn scalar-shape]) (λ (t z) (cond ((number? t) (f t z)) ((expects-preallocated? f) - (scalarize (flat-ext1-∇ f m shape-fn t (ensure-flat z)))) + (scalarize (flat-ext1-∇ f f-acc m shape-fn t (ensure-flat z)))) (else (let* ((in-shape (flat-shape t)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (scalarize (flat-ext1-∇ flat-f m shape-fn t (ensure-flat z))))))))) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-∇-acc + f-acc base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f flat-f-acc m + shape-fn t (ensure-flat z))))))))) (define functional->preallocated-1-ρ (λ (f base-shape out-shape) @@ -56,19 +66,37 @@ (λ (v-out i-out out-shape a) (cond ((null? out-shape) (vset! v-out i-out a)) - (else (v-copy-flat! v-out i-out a))))) + (else + (error 'ρ-functional-non-scalar-out + (string-append "Functional primitives can only return scalars," + " so try defining a preallocated primitive" + " instead. Out shape found: ~a") + out-shape) + #;(v-copy-flat! v-out i-out a))))) (define set-prealloc-∇! (λ (v-out i-out out-shape a) (cond ((null? out-shape) (vset! v-out i-out (+ (vref v-out i-out) a))) - (else (v-add-flat! v-out i-out a))))) + (else + (error '∇-functional-non-scalar-out + (string-append "Functional primitives can only return scalars," + " so try defining a preallocated primitive" + " instead. Out shape found: ~a") + out-shape) + #;(v-add-flat! v-out i-out a))))) (define arg-value (λ (v-shape v i) (cond ((null? v-shape) (vref v i)) - (else (flat v-shape v i))))) + (else + (error 'ρ-functional-non-scalar-in + (string-append "Functional primitives can only accept scalars," + " so try defining a preallocated primitive" + " instead. In shape found: ~a") + v-shape) + #;(flat v-shape v i))))) (define invoke-functional-∇ @@ -82,64 +110,70 @@ ;;—————————————————–—————————————————–—————————————————– (define ext2-ρ - (λ (f m n [shape-fn scalar-shape]) + (λ (f f-acc m n [shape-fn scalar-shape]) (λ (t u) (cond ((and (number? t) (number? u)) (f t u)) ((expects-preallocated? f) (scalarize - (flat-ext2-ρ f m n shape-fn t u))) + (flat-ext2-ρ f f-acc m n shape-fn t u))) ((number? t) (let* ((t-shape '()) (u-shape (min-shape n (flat-shape u))) (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) (scalarize - (flat-ext2-ρ flat-f m n shape-fn (ensure-flat t) u)))) + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn (ensure-flat t) u)))) ((number? u) (let* ((t-shape (min-shape m (flat-shape t))) (u-shape '()) (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) (scalarize - (flat-ext2-ρ flat-f m n shape-fn t (ensure-flat u))))) + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn t (ensure-flat u))))) (else (let* ((t-shape (min-shape m (flat-shape t))) (u-shape (min-shape n (flat-shape u))) (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) (scalarize - (flat-ext2-ρ flat-f m n shape-fn t u)))))))) + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn t u)))))))) (define ext2-∇ - (λ (f m n [shape-fn scalar-shape]) + (λ (f f-acc m n [shape-fn scalar-shape]) (λ (t u z) (let ((invoke-flat-ext2-∇ - (λ (f m n shape-fn t u z) - (let-values (((da db) (flat-ext2-∇ f m n shape-fn t u z))) + (λ (f f-acc m n shape-fn t u z) + (let-values (((da db) (flat-ext2-∇ f f-acc m n shape-fn t u z))) (values (scalarize da) (scalarize db)))))) (cond ((and (number? t) (number? u)) (f t u z)) ((expects-preallocated? f) - (invoke-flat-ext2-∇ f m n shape-fn t u z)) + (invoke-flat-ext2-∇ f f-acc m n shape-fn t u z)) ((number? t) (let* ((t-shape '()) (u-shape (min-shape n (flat-shape u))) (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) - (invoke-flat-ext2-∇ flat-f m n shape-fn (ensure-flat t) u z))) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn (ensure-flat t) u z))) ((number? u) (let* ((t-shape (min-shape m (flat-shape t))) (u-shape '()) (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) - (invoke-flat-ext2-∇ flat-f m n shape-fn t (ensure-flat u) z))) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t (ensure-flat u) z))) (else (let* ((t-shape (min-shape m (flat-shape t))) (u-shape (min-shape n (flat-shape u))) (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) - (invoke-flat-ext2-∇ flat-f m n shape-fn t u z)))))))) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t u z)))))))) (define functional->preallocated-2-ρ (λ (f t-shape u-shape out-shape) @@ -176,7 +210,7 @@ out-f-shape))) (define flat-ext1-ρ - (λ (f min-rank shape-fn t0) + (λ (f f-acc min-rank shape-fn t0) (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) (off0 (flat-offset t0)) @@ -189,14 +223,18 @@ (s-out (merge-shapes s0 min-rank sf-out)) (size-out (size-of s-out)) (v-out (new-vec size-out 0.0))) - (for ([i-out (in-range 0 size-out stride-out)] - #;[i0 (in-range off0 (+ off0 size0) stride0)]) - (define i0 (+ off0 (* (/ i-out stride-out) stride0))) - (f v0 i0 stride0 v-out i-out stride-out)) + (cond + ((accelerate?) (run-prim1-ρ! f-acc + v0 off0 size0 stride0 + v-out size-out stride-out)) + (else + (for ([i-out (in-range 0 size-out stride-out)]) + (let ((i0 (+ off0 (* (/ i-out stride-out) stride0)))) + (f v0 i0 stride0 v-out i-out stride-out))))) (flat s-out v-out 0)))) (define flat-ext1-∇ - (λ (fᵈ min-rank shape-fn t0 z) + (λ (fᵈ fᵈ-acc min-rank shape-fn t0 z) ;; z has the same shape as the output (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) @@ -210,25 +248,32 @@ (sf-z (shape-fn sf0)) (stride-z (size-of sf-z)) (vz (flat-store z)) + (offz (flat-offset z)) (g0 (new-vec size0 0.0))) - (for ([iz (in-range 0 size-z stride-z)] - #;[i0 (in-range off0 (+ off0 size0) stride0)]) - (define i0 (+ off0 (* (/ iz stride-z) stride0))) - (fᵈ g0 v0 i0 stride0 vz iz stride-z)) + (cond + ((accelerate?) (run-prim1-∇! fᵈ-acc g0 + v0 off0 size0 stride0 + vz offz size-z stride-z)) + (else + (for ([iz (in-range 0 size-z stride-z)]) + (let ((i0 (+ off0 (* (/ iz stride-z) stride0)))) + (fᵈ g0 v0 i0 stride0 vz (+ offz iz) stride-z))))) (flat s0 g0 0)))) (define flat-ext2-ρ - (λ (f r0 r1 shape-fn t0 t1) + (λ (f f-acc r0 r1 shape-fn t0 t1) (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) (off0 (flat-offset t0)) (sf0 (min-shape r0 s0)) + (size0 (size-of s0)) (s1 (flat-shape t1)) (v1 (flat-store t1)) (off1 (flat-offset t1)) (sf1 (min-shape r1 s1)) + (size1 (size-of s1)) (sf-out (shape-fn sf0 sf1)) (stride0 (size-of sf0)) @@ -237,24 +282,32 @@ (ext2-shapes s0 s1 r0 r1 sf-out (λ (s-out size-out q0 q1 strides) (let ((out-v (new-vec size-out 0.0))) - (for ([out-i (in-range 0 size-out stride-out)]) - (let-values (((i0 i1) - (idxs strides out-i off0 off1))) - (f v0 i0 stride0 v1 i1 stride1 out-v (+ 0 out-i) stride-out))) + (cond + ((accelerate?) (run-prim2-ρ! f-acc strides + v0 off0 size0 stride0 + v1 off1 size1 stride1 + out-v size-out stride-out)) + (else + (for ([out-i (in-range 0 size-out stride-out)]) + (let-values (((i0 i1) + (idxs strides out-i off0 off1))) + (f v0 i0 stride0 v1 i1 stride1 out-v (+ 0 out-i) stride-out))))) (flat s-out out-v 0))))))) (define flat-ext2-∇ - (λ (fᵈ r0 r1 shape-fn t0 t1 z) + (λ (fᵈ fᵈ-acc r0 r1 shape-fn t0 t1 z) (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) (off0 (flat-offset t0)) (sf0 (min-shape r0 s0)) + (size0 (size-of s0)) (stride0 (size-of sf0)) (s1 (flat-shape t1)) (v1 (flat-store t1)) (off1 (flat-offset t1)) (sf1 (min-shape r1 s1)) + (size1 (size-of s1)) (stride1 (size-of sf1)) (sf-z (shape-fn sf0 sf1)) @@ -265,10 +318,16 @@ (λ (sz size-z q0 q1 strides) (let ((g0 (new-vec (size-of s0) 0.0)) (g1 (new-vec (size-of s1) 0.0))) - (for ([iz (in-range 0 size-z stride-z)]) - (let-values (((i0 i1) - (idxs strides iz off0 off1))) - (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))) + (cond + ((accelerate?) (run-prim2-∇! fᵈ-acc strides g0 g1 + v0 off0 size0 stride0 + v1 off1 size1 stride1 + vz offz size-z stride-z)) + (else + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) + (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))))) (values (flat s0 g0 0) (flat s1 g1 0)))))))) diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt index 5b2c8d0..ae1a786 100644 --- a/accelerated-tensors/tensors/test/test-D-extend.rkt +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -1,5 +1,6 @@ (module+ test (require rackunit) + (require string-interpolation) (require "A-equality.rkt") (require "B-tensor-basics.rkt") @@ -9,11 +10,22 @@ (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) (+ sum (vref in-v i)))))) + (define sum-f-acc + (λ (v0 i0 stride0 v-out i-out stride-out) + #<string i0)] @@ -51,6 +51,57 @@ (binary-expr "+" i1 (binary-expr "*" idx stride1)) next-x))))))) +(define idx-exprs-inv + (λ (strides i-out repeats0 repeats1 s-out) + (λ (i0-var-str i1-var-str i-rep-var-str) + (let ((gen-expr + (λ (i-in-var-str stride-i repeats) + (for/fold ([i-out (number->string i-out)] + [dividend-rep i-rep-var-str] + [predivisor-rep repeats] + [x i-in-var-str] #:result i-out) + ([desc-out s-out] ;; s-out == (append descents-out sf-out) + [stride strides]) ;; (len strides) == (len descents-out) + (let ((stride-out (vector-ref stride 0)) + (stride-in (vector-ref stride stride-i))) + (cond + ((zero? stride-in) + (let* ((divisor-rep (quotient predivisor-rep desc-out)) + (divisor-rep-str (number->string divisor-rep)) + (scaling (binary-expr "/" dividend-rep divisor-rep-str)) + (next-dividend (binary-expr "%" + dividend-rep + divisor-rep-str))) + (values (binary-expr "+" i-out + (binary-expr "*" + scaling + (number->string + stride-out))) + next-dividend + divisor-rep + x))) + (else + (let ((stride-in-str (number->string stride-in))) + (let ((idx (binary-expr "/" x stride-in-str)) + (next-x (binary-expr "%" x stride-in-str))) + (values (binary-expr "+" i-out + (binary-expr "*" idx + (number->string + stride-out))) + dividend-rep + predivisor-rep + next-x)))))))))) + (values (gen-expr i0-var-str 1 repeats0) + (gen-expr i1-var-str 2 repeats1)))))) + +(define calc-repeats + (λ (s0 s1 r0 r1 s-out r-out) + (define size-rep0 (apply * (drop-right s0 r0))) + (define size-rep1 (apply * (drop-right s1 r1))) + (define size-rep-out (apply * (drop-right s-out r-out))) + (values (/ size-rep-out size-rep0) + (/ size-rep-out size-rep1)))) + (define (ext1-ρ-kernel prim1-ρ-f) #<bytes/utf-8 - (ext1-ρ-kernel prim-kernel-f))))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf0) - (clSetKernelArg:_cl_int kernel 1 stride0) - (clSetKernelArg:_cl_mem kernel 2 buf-out) - (clSetKernelArg:_cl_int kernel 3 stride-out)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-out stride-out)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size-out) - (vec->cpointer v-out) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-out - (clReleaseMemObject buf-out)) - (when buf0 - (clReleaseMemObject buf0)))))))) + (λ () + (let* ([buf0 #f] + [buf-out #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf-out (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size-out) + #f)) + (set! program (clCreateProgramWithSource (context) + (make-vector + 1 + (string->bytes/utf-8 + kernel-code)))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf-out) + (clSetKernelArg:_cl_int kernel 3 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf0 + (clReleaseMemObject buf0)))))))) (define functional->preallocated-1-ρ-acc (λ (f-acc base-shape out-shape) @@ -156,67 +206,66 @@ __kernel void Kernel (__global float* g0, EOF ) -(define (run-prim1-∇! prim-kernel-f g0 +(define (run-prim1-∇! kernel-code g0 v0 off0 size0 stride0 vz offz size-z stride-z) (with-opencl - (λ () - (let* ([buf0 #f] - [buf-z #f] - [buf-g #f] - [program #f] - [kernel #f] - [event #f]) - (dynamic-wind - (λ () - ;; Exclude memory consumed by elements before offset of input vector v0 - (set! buf0 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size0) - (vref-cpointer v0 off0))) - (set! buf-z (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size-z) - (vref-cpointer vz offz))) - (set! buf-g (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY - (* (ctype-sizeof _cl_float) - size0) - #f)) - (set! program (clCreateProgramWithSource - (context) - (make-vector - 1 - (string->bytes/utf-8 - (ext1-∇-kernel prim-kernel-f))))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf-g) - (clSetKernelArg:_cl_mem kernel 1 buf0) - (clSetKernelArg:_cl_int kernel 2 stride0) - (clSetKernelArg:_cl_mem kernel 3 buf-z) - (clSetKernelArg:_cl_int kernel 4 stride-z)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-z stride-z)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size0) - (vec->cpointer g0) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-g - (clReleaseMemObject buf-g)) - (when buf-z - (clReleaseMemObject buf-z)) - (when buf0 - (clReleaseMemObject buf0)))))))) + (λ () + (let* ([buf0 #f] + [buf-z #f] + [buf-g #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf-z (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size-z) + (vref-cpointer vz offz))) + (set! buf-g (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size0) + #f)) + + (set! program (clCreateProgramWithSource (context) + (make-vector + 1 + (string->bytes/utf-8 + kernel-code)))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf-g) + (clSetKernelArg:_cl_mem kernel 1 buf0) + (clSetKernelArg:_cl_int kernel 2 stride0) + (clSetKernelArg:_cl_mem kernel 3 buf-z) + (clSetKernelArg:_cl_int kernel 4 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-z stride-z)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g + (clReleaseMemObject buf-g)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf0 + (clReleaseMemObject buf0)))))))) (define functional->preallocated-1-∇-acc (λ (f-acc base-shape out-shape) @@ -235,8 +284,9 @@ EOF EOF )))) -(define (ext2-ρ-kernel prim2-ρ-f generate-idxs) - (let-values (((i0-expr i1-expr) (generate-idxs "i_out"))) +(define (ext2-ρ-kernel prim2-ρ-f strides) + (let*-values (((generate-idxs) (idx-exprs strides 0 0)) + ((i0-expr i1-expr) (generate-idxs "i_out"))) #<bytes/utf-8 - (ext2-ρ-kernel prim-kernel-f - (idx-exprs-gen strides 0 0)))))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf0) - (clSetKernelArg:_cl_int kernel 1 stride0) - (clSetKernelArg:_cl_mem kernel 2 buf1) - (clSetKernelArg:_cl_int kernel 3 stride1) - (clSetKernelArg:_cl_mem kernel 4 buf-out) - (clSetKernelArg:_cl_int kernel 5 stride-out)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-out stride-out)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size-out) - (vec->cpointer v-out) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-out - (clReleaseMemObject buf-out)) - (when buf1 - (clReleaseMemObject buf1)) - (when buf0 - (clReleaseMemObject buf0)))))))) + (λ () + (let* ([buf0 #f] + [buf1 #f] + [buf-out #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf1 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size1) + (vref-cpointer v1 off1))) + (set! buf-out (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size-out) + #f)) + (set! program (clCreateProgramWithSource + (context) + (make-vector + 1 + (string->bytes/utf-8 kernel-code)))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf1) + (clSetKernelArg:_cl_int kernel 3 stride1) + (clSetKernelArg:_cl_mem kernel 4 buf-out) + (clSetKernelArg:_cl_int kernel 5 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))))) (define functional->preallocated-2-ρ-acc (λ (f-acc t-shape u-shape out-shape) @@ -339,8 +386,13 @@ EOF EOF )))) -(define (ext2-∇-kernel prim2-∇-f generate-idxs) - (let-values (((i0-expr i1-expr) (generate-idxs "iz"))) +(define (ext2-∇-kernel-atomic prim2-∇-f strides) + (let*-values (((prim-effect0 prim-effect1) (prim2-∇-f "g" + "v0" "i0" "stride0" + "v1" "i1" "stride1" + "vz" "iz" "stride_z")) + ((generate-idxs) (idx-exprs strides 0 0)) + ((i0-expr i1-expr) (generate-idxs "iz"))) #<bytes/utf-8 kernel-code)))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf-g0) + (clSetKernelArg:_cl_mem kernel 1 buf-g1) + (clSetKernelArg:_cl_mem kernel 2 buf0) + (clSetKernelArg:_cl_int kernel 3 stride0) + (clSetKernelArg:_cl_mem kernel 4 buf1) + (clSetKernelArg:_cl_int kernel 5 stride1) + (clSetKernelArg:_cl_mem kernel 6 buf-z) + (clSetKernelArg:_cl_int kernel 7 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-z stride-z)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g1 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size1) + (vec->cpointer g1) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g1 + (clReleaseMemObject buf-g1)) + (when buf-g0 + (clReleaseMemObject buf-g0)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))))) + +(define (run-prim2-∇-split! kernel-code0 kernel-code1 g0 g1 + v0 off0 size0 stride0 + v1 off1 size1 stride1 + vz offz size-z stride-z) + (with-opencl + (λ () + (define (run! kernel-code g size-in stride-in) + (let* ([buf0 #f] + [buf1 #f] + [buf-z #f] + [buf-g #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + ;; Exclude memory consumed by elements before offset of input vector v0 + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) (* (ctype-sizeof _cl_float) size0) - #f)) - (set! buf-g1 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (vref-cpointer v0 off0))) + (set! buf1 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) (* (ctype-sizeof _cl_float) size1) - #f)) - (printf "###Source:~n~a~n" - (ext2-∇-kernel prim-kernel-f - (idx-exprs-gen strides 0 0))) - (set! program (clCreateProgramWithSource - (context) - (make-vector - 1 - (string->bytes/utf-8 - (ext2-∇-kernel prim-kernel-f - (idx-exprs-gen strides 0 0)))))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf-g0) - (clSetKernelArg:_cl_mem kernel 1 buf-g1) - (clSetKernelArg:_cl_mem kernel 2 buf0) - (clSetKernelArg:_cl_int kernel 3 stride0) - (clSetKernelArg:_cl_mem kernel 4 buf1) - (clSetKernelArg:_cl_int kernel 5 stride1) - (clSetKernelArg:_cl_mem kernel 6 buf-z) - (clSetKernelArg:_cl_int kernel 7 stride-z)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-z stride-z)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size0) - (vec->cpointer g0) (vector event))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g1 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size1) - (vec->cpointer g1) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-g1 - (clReleaseMemObject buf-g1)) - (when buf-g0 - (clReleaseMemObject buf-g0)) - (when buf-z - (clReleaseMemObject buf-z)) - (when buf1 - (clReleaseMemObject buf1)) - (when buf0 - (clReleaseMemObject buf0)))))))) + (vref-cpointer v1 off1))) + (set! buf-z (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size-z) + (vref-cpointer vz offz))) + (set! buf-g (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size-in) + #f)) + (set! program (clCreateProgramWithSource + (context) + (make-vector 1 (string->bytes/utf-8 kernel-code)))) + (clBuildProgram program (make-vector 0) (make-bytes 0)) + (set! kernel (clCreateKernel program #"Kernel")) + (clSetKernelArg:_cl_mem kernel 0 buf-g) + (clSetKernelArg:_cl_mem kernel 1 buf0) + (clSetKernelArg:_cl_int kernel 2 stride0) + (clSetKernelArg:_cl_mem kernel 3 buf1) + (clSetKernelArg:_cl_int kernel 4 stride1) + (clSetKernelArg:_cl_mem kernel 5 buf-z) + (clSetKernelArg:_cl_int kernel 6 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-in stride-in)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-in) + (vec->cpointer g) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g + (clReleaseMemObject buf-g)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))) + (run! kernel-code0 g0 size0 stride0) + (run! kernel-code1 g1 size1 stride1)))) (define functional->preallocated-2-∇-acc (λ (f-acc t-shape u-shape out-shape) @@ -464,18 +634,24 @@ EOF " preallocated primitive instead." " Input 1, input 2 and output shape found: ~a ~a ~a") t-shape u-shape out-shape)) - (λ (g0 g1 v0 i0 stride0 v1 i1 stride1 vz iz stride-z) + (λ (g v0 i0 stride0 v1 i1 stride1 vz iz stride-z) (let ((z "@{vz}[@{iz}]") (a "@{v0}[@{i0}]") (b "@{v1}[@{i1}]")) (let-values (((da db) (f-acc a b z))) - #<preallocated-1-ρ-acc - run-prim1-∇! functional->preallocated-1-∇-acc - run-prim2-ρ! functional->preallocated-2-ρ-acc - run-prim2-∇! functional->preallocated-2-∇-acc) +(provide run-prim1-ρ! functional->preallocated-1-ρ-acc ext1-ρ-kernel + run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel + run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel + run-prim2-∇-atomic! run-prim2-∇-split! + functional->preallocated-2-∇-acc + ext2-∇-kernel-atomic ext2-∇-kernel-split) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index 12fd7b0..3421a75 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -224,7 +224,7 @@ (size-out (size-of s-out)) (v-out (new-vec size-out 0.0))) (cond - ((accelerate?) (run-prim1-ρ! f-acc + ((accelerate?) (run-prim1-ρ! (ext1-ρ-kernel f-acc) v0 off0 size0 stride0 v-out size-out stride-out)) (else @@ -252,7 +252,7 @@ (g0 (new-vec size0 0.0))) (cond - ((accelerate?) (run-prim1-∇! fᵈ-acc g0 + ((accelerate?) (run-prim1-∇! (ext1-∇-kernel fᵈ-acc) g0 v0 off0 size0 stride0 vz offz size-z stride-z)) (else @@ -280,10 +280,10 @@ (stride1 (size-of sf1)) (stride-out (size-of sf-out))) (ext2-shapes s0 s1 r0 r1 sf-out - (λ (s-out size-out q0 q1 strides) + (λ (s-out size-out q0 q1 strides parallel-desc?) (let ((out-v (new-vec size-out 0.0))) (cond - ((accelerate?) (run-prim2-ρ! f-acc strides + ((accelerate?) (run-prim2-ρ! (ext2-ρ-kernel f-acc strides) v0 off0 size0 stride0 v1 off1 size1 stride1 out-v size-out stride-out)) @@ -315,18 +315,28 @@ (vz (flat-store z)) (offz (flat-offset z))) (ext2-shapes s0 s1 r0 r1 sf-z - (λ (sz size-z q0 q1 strides) + (λ (sz size-z q0 q1 strides parallel-desc?) (let ((g0 (new-vec (size-of s0) 0.0)) (g1 (new-vec (size-of s1) 0.0))) (cond - ((accelerate?) (run-prim2-∇! fᵈ-acc strides g0 g1 - v0 off0 size0 stride0 - v1 off1 size1 stride1 - vz offz size-z stride-z)) + ((accelerate?) + (cond + (parallel-desc? (run-prim2-∇-atomic! (ext2-∇-kernel-atomic fᵈ-acc strides) + g0 g1 + v0 off0 size0 stride0 + v1 off1 size1 stride1 + vz offz size-z stride-z)) + (else + (let*-values (((kernel-code0 kernel-code1) + (ext2-∇-kernel-split fᵈ-acc strides s0 s1 r0 r1 sz (length sf-z)))) + (run-prim2-∇-split! kernel-code0 kernel-code1 + g0 g1 + v0 off0 size0 stride0 + v1 off1 size1 stride1 + vz offz size-z stride-z))))) (else (for ([iz (in-range 0 size-z stride-z)]) - (let-values (((i0 i1) - (idxs strides iz off0 off1))) + (let-values (((i0 i1) (idxs strides iz off0 off1))) (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))))) (values (flat s0 g0 0) (flat s1 g1 0)))))))) @@ -341,7 +351,8 @@ (size-of sf-out) (size-of s0) (size-of s1) - '())) + '() + #t)) ((= r0 l0) (ext2-shapes s0 (cdr s1) r0 r1 sf-out @@ -371,30 +382,33 @@ (define desc-both (λ (d k) - (λ (s-out qout q0 q1 strides) + (λ (s-out qout q0 q1 strides parallel-desc?) (k (cons d s-out) (* qout d) (* q0 d) (* q1 d) - (cons (vector qout q0 q1) strides))))) + (cons (vector qout q0 q1) strides) + parallel-desc?)))) (define desc-left (λ (d k) - (λ (s-out qout q0 q1 strides) + (λ (s-out qout q0 q1 strides parallel-desc?) (k (cons d s-out) (* qout d) (* q0 d) q1 - (cons (vector qout q0 0) strides))))) + (cons (vector qout q0 0) strides) + #f)))) (define desc-right (λ (d k) - (λ (s-out qout q0 q1 strides) + (λ (s-out qout q0 q1 strides parallel-desc?) (k (cons d s-out) (* qout d) q0 (* q1 d) - (cons (vector qout 0 q1) strides))))) + (cons (vector qout 0 q1) strides) + #f)))) (define v-copy-flat! (λ (vg ig a) diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt index ae1a786..f26e605 100644 --- a/accelerated-tensors/tensors/test/test-D-extend.rkt +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -94,7 +94,7 @@ EOF (define r1 1) (ext2-shapes s0 s1 r0 r1 '(5 6) - (λ (s-out size-out q0 q1 strides) + (λ (s-out size-out q0 q1 strides parallel-desc?) (check-equal? s-out '(3 4 7 5 6)) (check-equal? size-out 2520) (check-equal? strides '(#(840 120 42) #(210 30 0) #(30 0 6))) @@ -255,7 +255,6 @@ EOF (check-tensor-equal? db (tensor 1.0 1.0 1.0))) (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) - (print-vec (flat-store da)) (check-tensor-equal? da (tensor 2.0 2.0 2.0)) (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) From 80c547115e40080146c8df2787dc3e69bc572b52 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:10:48 -0400 Subject: [PATCH 38/83] =?UTF-8?q?[add-acc]use=20one=20kernel=20in=20ext2-?= =?UTF-8?q?=E2=88=87=20and=20accelerate=20all=20ext*=20function=20calls?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- accelerated-tensors.rkt | 2 + accelerated-tensors/autodiff/A-autodiff.rkt | 5 +- accelerated-tensors/autodiff/B-prims.rkt | 26 +- .../autodiff/D-test-helpers.rkt | 10 +- accelerated-tensors/ext-ops/A-scalar-ops.rkt | 132 +++++++--- accelerated-tensors/ext-ops/B-comparators.rkt | 34 ++- accelerated-tensors/ext-ops/C-star-2-1.rkt | 39 ++- accelerated-tensors/ext-ops/D-sum.rkt | 57 ++++- accelerated-tensors/ext-ops/E-argmax.rkt | 34 ++- accelerated-tensors/ext-ops/F-max.rkt | 43 +++- accelerated-tensors/ext-ops/G-correlate.rkt | 65 ++++- accelerated-tensors/ext-ops/I-flatten.rkt | 25 +- accelerated-tensors/ext-ops/K-concat.rkt | 38 ++- .../ext-ops/test/test-D-sum.rkt | 22 +- .../ext-ops/test/test-E-argmax.rkt | 4 + .../ext-ops/test/test-G-correlate.rkt | 4 +- .../ext-ops/test/test-I-flatten.rkt | 1 + accelerated-tensors/tensors/2-acc-runtime.rkt | 239 ++++++------------ accelerated-tensors/tensors/A-equality.rkt | 7 +- accelerated-tensors/tensors/D-extend.rkt | 45 ++-- impl-loader.rkt | 9 +- malted/test/test-A-core.rkt | 9 +- 22 files changed, 584 insertions(+), 266 deletions(-) diff --git a/accelerated-tensors.rkt b/accelerated-tensors.rkt index 158fb6a..5fa0292 100644 --- a/accelerated-tensors.rkt +++ b/accelerated-tensors.rkt @@ -8,6 +8,8 @@ (require "accelerated-tensors/ext-ops.rkt") (provide + tolerance + len ref refr tref tlen list->tensor tensor build-tensor diff --git a/accelerated-tensors/autodiff/A-autodiff.rkt b/accelerated-tensors/autodiff/A-autodiff.rkt index 03b9c93..afcd114 100644 --- a/accelerated-tensors/autodiff/A-autodiff.rkt +++ b/accelerated-tensors/autodiff/A-autodiff.rkt @@ -1,6 +1,7 @@ #lang racket (require "../tensors.rkt") +(require string-interpolation) ;;---------------------------- ;; Real part of a dual is always a tensor (of any rank) @@ -52,7 +53,7 @@ (hash-set σ d (+-ρ z g))))) (define +-ρ - (ext2-ρ + 0 0)) + (ext2-ρ + (λ (a b) "@{a} + @{b}") 0 0)) ;;---------------------------- ;; Reverse-mode AD @@ -111,7 +112,7 @@ ((dual? v) (trace-print (ρ v) port)) (else (fprintf port "~a~%" v))))) -(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) +(define (one-like s) ((ext1-ρ (λ (x) 1.0) (λ (x) "1.0") 0) s)) (include "test/test-A-autodiff.rkt") diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt index e029cc9..416e254 100644 --- a/accelerated-tensors/autodiff/B-prims.rkt +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -7,9 +7,15 @@ (define ρ-function (λ (f) (f ρ-function))) +(define ρ-acc-function + (λ (f) (f ρ-acc-function))) + (define ∇-function (λ (f) (f ∇-function))) +(define ∇-acc-function + (λ (f) (f ∇-acc-function))) + (define shape-fn (λ (f) (f shape-fn))) @@ -35,13 +41,15 @@ ;; (define prim1 - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) (∇-callable (ensure-∇-callable-1 ∇-fn shape))) (λ (daf) (cond ((eq? daf ρ-function) ρ-fn) + ((eq? daf ρ-acc-function) ρ-acc-fn) ((eq? daf ∇-function) ∇-fn) + ((eq? daf ∇-acc-function) ∇-acc-fn) ((eq? daf shape-fn) shape) (else (prim1-dual ρ-callable ∇-callable daf))))))) @@ -54,14 +62,16 @@ ((κ da) da ga σ))))))) (define prim2 - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) (∇-callable (ensure-∇-callable-2 ∇-fn shape))) (λ ds (let ((daf (ref ds 0))) (cond ((eq? daf ρ-function) ρ-fn) + ((eq? daf ρ-acc-function) ρ-acc-fn) ((eq? daf ∇-function) ∇-fn) + ((eq? daf ∇-acc-function) ∇-acc-fn) ((eq? daf shape-fn) shape) (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1))))))))) @@ -207,15 +217,19 @@ (define ext1 (λ (f n) (prim1 - (ext1-ρ (ρ-function f) n (shape-fn f)) - (ext1-∇ (∇-function f) n (shape-fn f)) + (ext1-ρ (ρ-function f) (ρ-acc-function f) n (shape-fn f)) + (ρ-acc-function f) + (ext1-∇ (∇-function f) (∇-acc-function f) n (shape-fn f)) + (∇-acc-function f) (shape-fn f)))) (define ext2 (λ (f m n) (prim2 - (ext2-ρ (ρ-function f) m n (shape-fn f)) - (ext2-∇ (∇-function f) m n (shape-fn f)) + (ext2-ρ (ρ-function f) (ρ-acc-function f) m n (shape-fn f)) + (ρ-acc-function f) + (ext2-∇ (∇-function f) (∇-acc-function f) m n (shape-fn f)) + (∇-acc-function f) (shape-fn f)))) (provide prim1 prim2 ext1 ext2) diff --git a/accelerated-tensors/autodiff/D-test-helpers.rkt b/accelerated-tensors/autodiff/D-test-helpers.rkt index 5b21797..208f2b1 100644 --- a/accelerated-tensors/autodiff/D-test-helpers.rkt +++ b/accelerated-tensors/autodiff/D-test-helpers.rkt @@ -2,10 +2,14 @@ (require "../tensors.rkt") (require "A-autodiff.ss") +(require "E-print.ss") (require rackunit) -(define-binary-check (check-dual-equal? equal-wt? actual expected)) +(define-check (check-dual-equal? actual expected) + (unless (equal-wt? actual expected) + (fail-check (format "Duals failed to match.~%actual:~%~s~%expected:~s~%" + (make-printable actual) (make-printable expected))))) (define-check (ρ-∇-checker fn args ans grads) (let* ((y (apply fn args)) (g (apply (∇¹ fn) args))) @@ -14,10 +18,10 @@ (equal-wt? grads (ρ g))) (void)) ((equal-wt? ans (ρ y)) (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" - (ρ g) grads))) + (make-printable (ρ g)) (make-printable grads)))) (else (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" - (ρ y) ans)))))) + (make-printable (ρ y)) (make-printable ans))))))) (define-syntax check-ρ-∇ (syntax-rules () diff --git a/accelerated-tensors/ext-ops/A-scalar-ops.rkt b/accelerated-tensors/ext-ops/A-scalar-ops.rkt index b251e80..7893190 100644 --- a/accelerated-tensors/ext-ops/A-scalar-ops.rkt +++ b/accelerated-tensors/ext-ops/A-scalar-ops.rkt @@ -1,49 +1,108 @@ #lang racket +(require string-interpolation) (require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) (require "../autodiff.rkt") +(define +-0-0-ρ-acc + (λ (a b) + "@{a}+@{b}")) + (define +-0-0 (prim2 + + +-0-0-ρ-acc + (λ (a b z) + (values z z)) (λ (a b z) (values z z)))) +(define --0-0-ρ-acc + (λ (a b) + "@{a}-@{b}")) + (define --0-0 (prim2 - + --0-0-ρ-acc + (λ (a b z) + (values z (- z))) (λ (a b z) - (values z (- z))))) + (values z "(- @{z})")))) + +(define *-0-0-ρ-acc + (λ (a b) + "@{a}*@{b}")) (define *-0-0 (prim2 * + *-0-0-ρ-acc + (λ (a b z) + (values (* b z) (* a z))) (λ (a b z) - (values (* b z) (* a z))))) + (values "@{b}*@{z}" "@{a}*@{z}")))) + +(define /-0-0-ρ-acc + (λ (a b) + "@{a}/@{b}")) (define /-0-0 (prim2 / - (λ (a b z) - (values (* z (/ 1 b)) - (* z (/ (- a) (* b b))))))) + /-0-0-ρ-acc + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))) + (λ (a b z) + (values "(@{z} * (1 / @{b}))" + "(@{z} * ((- @{a}) / (@{b} * @{b})))")))) + +(define expt-0-0-ρ-acc + (λ (a b) + "pow(@{a}, @{b})")) (define expt-0-0 (prim2 expt - (λ (a b z) - (values (* z (* b (expt a (- b 1)))) - (* z (* (expt a b) (log a))))))) + expt-0-0-ρ-acc + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))) + (λ (a b z) + (values "(@{z} * (@{b} * pow(@{a}, (@{b} - 1))))" + "(@{z} * (pow(@{a}, @{b}) * log(@{a})))")))) + +(define exp-0-ρ-acc + (λ (a) + "exp(@{a})")) (define exp-0 (prim1 exp - (λ (a z) - (* z (exp a))))) + exp-0-ρ-acc + (λ (a z) + (* z (exp a))) + (λ (a z) + "(@{z} * exp(@{a}))"))) + +(define log-0-ρ-acc + (λ (a) + "log(@{a})")) (define log-0 (prim1 log - (λ (a z) - (* z (/ 1 a))))) + log-0-ρ-acc + (λ (a z) + (* z (/ 1 a))) + (λ (a z) + "(@{z} * (1 / @{a}))"))) + +(define sqrt-0-ρ-acc + (λ (a) + "sqrt(@{a})")) (define sqrt-0 (prim1 sqrt - (λ (x z) - (/ z (* 2 (sqrt x)))))) + sqrt-0-ρ-acc + (λ (x z) + (/ z (* 2 (sqrt x)))) + (λ (x z) + "(@{z} / (2 * sqrt(@{x})))"))) (define abs-0-ρ (λ (x) @@ -51,14 +110,22 @@ ((< x 0) (* -1 x)) (else x)))) +(define abs-0-ρ-acc + (λ (x) + "fabs(@{x})")) + (define abs-0-∇ (λ (x z) (cond ((< x 0) (- z)) (else z)))) +(define abs-0-∇-acc + (λ (x z) + "sign(@{x}) * @{z}")) + (define abs-0 - (prim1 abs-0-ρ abs-0-∇)) + (prim1 abs-0-ρ abs-0-ρ-acc abs-0-∇ abs-0-∇-acc)) (define rectify-0-ρ (λ (s) @@ -66,17 +133,25 @@ ((< s 0.0) 0.0) (else s)))) +(define rectify-0-ρ-acc + (λ (s) + "fmax(0.0f, @{s})")) + (define rectify-0-∇ (λ (s z) (cond ((< s 0.0) 0.0) (else z)))) +(define rectify-0-∇-acc + (λ (s z) + "step(0, @{s}) * @{z}")) + (define rectify-shape (λ (s) s)) (define rectify-0 - (prim1 rectify-0-ρ rectify-0-∇ rectify-shape)) + (prim1 rectify-0-ρ rectify-0-ρ-acc rectify-0-∇ rectify-0-∇-acc rectify-shape)) ;;------------------------------------ ;; differentiable extended functions. @@ -102,20 +177,17 @@ ;; non-differentiable extended functions. ;;------------------------------------ -(define *-ρ (ext2-ρ * 0 0)) -(define +-ρ (ext2-ρ + 0 0)) -(define --ρ (ext2-ρ - 0 0)) -(define /-ρ (ext2-ρ / 0 0)) -(define expt-ρ (ext2-ρ expt 0 0)) - -(define exp-ρ (ext1-ρ exp 0)) -(define log-ρ (ext1-ρ log 0)) -(define abs-ρ (ext1-ρ abs-0-ρ 0)) -(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) - -(define sqrt-ρ - (λ (a) - (expt-ρ a 1/2))) +(define *-ρ (ext2-ρ * *-0-0-ρ-acc 0 0)) +(define +-ρ (ext2-ρ + +-0-0-ρ-acc 0 0)) +(define --ρ (ext2-ρ - --0-0-ρ-acc 0 0)) +(define /-ρ (ext2-ρ / /-0-0-ρ-acc 0 0)) +(define expt-ρ (ext2-ρ expt expt-0-0-ρ-acc 0 0)) + +(define exp-ρ (ext1-ρ exp exp-0-ρ-acc 0)) +(define log-ρ (ext1-ρ log log-0-ρ-acc 0)) +(define abs-ρ (ext1-ρ abs-0-ρ abs-0-ρ-acc 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ rectify-0-ρ-acc 0)) +(define sqrt-ρ (ext1-ρ sqrt sqrt-0-ρ-acc 0)) (define sqr-ρ (λ (x) diff --git a/accelerated-tensors/ext-ops/B-comparators.rkt b/accelerated-tensors/ext-ops/B-comparators.rkt index c42a2cf..3db8e0f 100644 --- a/accelerated-tensors/ext-ops/B-comparators.rkt +++ b/accelerated-tensors/ext-ops/B-comparators.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../autodiff.rkt") ;;---------------------------- @@ -37,6 +38,11 @@ ((f (ρ da) (ρ db)) 1.0) (else 0.0))))) +(define comparator-ρ-acc + (λ (f) + (λ (a b) + "@{a} @{f} @{b}"))) + (define comparator-∇ (λ (f) (λ (da db z) @@ -44,40 +50,48 @@ ((f (ρ da) (ρ db)) (values z z)) (else (values 0.0 0.0)))))) +(define comparator-∇-acc + (λ (f) + (λ (a b z) + (let ((bool "@{a} @{f} @{b}")) + (values "@{bool}*@{z}" "@{bool}*@{z}"))))) + (define comparator-shape (λ (f) (λ (sa sb) sa))) (define comparator-prim - (λ (f) - (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + (λ (f f-acc) + (prim2 (comparator-ρ f) (comparator-ρ-acc f-acc) + (comparator-∇ f) (comparator-∇-acc f-acc) + (comparator-shape f)))) (define extended-comparator - (λ (f) - (ext2 (comparator-prim f) 0 0))) + (λ (f f-acc) + (ext2 (comparator-prim f f-acc) 0 0))) (define =-1 - (extended-comparator =)) + (extended-comparator = "==")) (define <-1 - (extended-comparator <)) + (extended-comparator < "<")) (define >-1 - (extended-comparator >)) + (extended-comparator > ">")) (define <=-1 - (extended-comparator <=)) + (extended-comparator <= "<=")) (define >=-1 - (extended-comparator >=)) + (extended-comparator >= ">=")) (define != (λ (a b) (not (= a b)))) (define !=-1 - (extended-comparator !=)) + (extended-comparator != "!=")) (include "test/test-B-comparators.rkt") diff --git a/accelerated-tensors/ext-ops/C-star-2-1.rkt b/accelerated-tensors/ext-ops/C-star-2-1.rkt index 5eb0d63..8cb8ab8 100644 --- a/accelerated-tensors/ext-ops/C-star-2-1.rkt +++ b/accelerated-tensors/ext-ops/C-star-2-1.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../tensors/0-vectors.rkt") (require (only-in "../tensors.rkt" ext2-ρ)) (require "../autodiff.rkt") @@ -13,6 +14,17 @@ (* (vref v0 (+ i0 i)) (vref v1 (+ i1 (modulo i stride1)))))))) +(define *-2-1-base-ρ-acc + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + #< v max) (values v (+ (- i i0) 0.0))) (else (values max max-i)))))))) +(define argmax-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< max) { + max = v; + max_i = i - @{i0} + 0.0; + } + } + @{v-out}[@{i-out}] = max_i; +EOF + )) + (define argmax-1-∇ (λ (g0 v0 i0 stride0 vz iz stride-z) @@ -23,18 +41,28 @@ (for ([i (in-range i0 (+ i0 stride0))]) (vset! g0 i 0.0))))) +(define argmax-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< v max) v) (else max))))))) +(define max-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< v max) (values v (- i i0))) (else (values max max-i)))))))) +(define max-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< max) { + max = v; + max_i = i - @{i0}; + } + } + for(int i=@{i0}; i<@{i0}+@{stride0}; i++) { + if(i == @{i0}+max_i) { + @{g0}[i] = z; + } else { + @{g0}[i] = 0.0; + } + } +EOF + )) + (define max-shape (λ (st) (cdr st))) (define max-1 - (prim1 max-1-ρ max-1-∇ max-shape)) + (prim1 max-1-ρ max-1-ρ-acc max-1-∇ max-1-∇-acc max-shape)) (define d-max (ext1 max-1 1)) (define max-ρ - (ext1-ρ max-1-ρ 1 max-shape)) + (ext1-ρ max-1-ρ max-1-ρ-acc 1 max-shape)) (include "test/test-F-max.rkt") diff --git a/accelerated-tensors/ext-ops/G-correlate.rkt b/accelerated-tensors/ext-ops/G-correlate.rkt index 9db2109..1e2ba84 100644 --- a/accelerated-tensors/ext-ops/G-correlate.rkt +++ b/accelerated-tensors/ext-ops/G-correlate.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../tensors/0-vectors.rkt") (require (only-in "../tensors.rkt" ext2-ρ len)) (require "../autodiff.rkt") @@ -29,6 +30,28 @@ (+ sum (* a b)))) (else sum)))))))))) +(define correlate-3-1-ρ-acc + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + #<= i1_min && bi < i1_max) { + sum += @{v0}[ai] * @{v1}[bi]; + } + } + @{v-out}[@{i-out}+i] = sum; + } +EOF + ))) + (define correlate-3-1-∇ (λ (nd md qd) (λ (g0 g1 @@ -50,6 +73,44 @@ (vset! g1 bi (+ (vref g1 bi) (* z a))))))))))))) +(define correlate-3-1-∇-acc + (λ (nd md qd) + (λ (g + v0 i0 bmd + v1 i1 d + vz iz b) + (values + #<= i1_min && bi < i1_max) { + @{g}[ai] += z * @{v1}[bi]; + } + } + } +EOF + + #<= i1_min && bi < i1_max) { + @{g}[bi] += z * @{v0}[ai]; + } + } + } +EOF + )))) + (define correlate-shape (λ (bmd nd) (list (car bmd)))) @@ -58,7 +119,9 @@ (λ (nd md qd) (prim2 (correlate-3-1-ρ nd md qd) + (correlate-3-1-ρ-acc nd md qd) (correlate-3-1-∇ nd md qd) + (correlate-3-1-∇-acc nd md qd) correlate-shape))) (define d-correlate @@ -83,7 +146,7 @@ (q (/ (- m 1) 2)) ;; This is the padding. (qd (* q d)) (md (* m d))) - ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape) + ((ext2-ρ (correlate-3-1-ρ nd md qd) (correlate-3-1-ρ-acc nd md qd) 3 1 correlate-shape) bank signal)))) (define last diff --git a/accelerated-tensors/ext-ops/I-flatten.rkt b/accelerated-tensors/ext-ops/I-flatten.rkt index bf24773..0ef02f6 100644 --- a/accelerated-tensors/ext-ops/I-flatten.rkt +++ b/accelerated-tensors/ext-ops/I-flatten.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) (require (only-in "../autodiff.rkt" prim1 ext1)) @@ -7,10 +8,30 @@ (λ (t) (reshape (flatten-shape (shape t)) t))) +(define flatten-2-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #<vector cv) (build-vector (cvector-length cv) @@ -17,14 +19,14 @@ (define (with-opencl th) (let* ([platform (cvector-ref (clGetPlatformIDs:vector) 0)] [devices (clGetDeviceIDs:vector platform 'CL_DEVICE_TYPE_GPU)] - [device-idx 0] - [device (cvector-ref devices device-idx)]) + [device-idx 0]) (parameterize* ([context #f] - [command-queue #f]) + [command-queue #f] + [device (cvector-ref devices device-idx)]) (dynamic-wind (λ () (context (clCreateContext #f (cvector->vector devices))) - (command-queue (clCreateCommandQueue (context) device '()))) + (command-queue (clCreateCommandQueue (context) (device) '()))) th (λ () (when (command-queue) @@ -32,6 +34,18 @@ (when (context) (clReleaseContext (context)))))))) +(define print-cl-build-log + (λ (program _) + (when (debug-kernel?) + (printf "Program Source:~n~a~n" + (clGetProgramInfo:generic program 'CL_PROGRAM_SOURCE)) + (printf "Build status:~a~n" + (clGetProgramBuildInfo:generic program (device) + 'CL_PROGRAM_BUILD_STATUS)) + (printf "Build log:~a~n" + (clGetProgramBuildInfo:generic program (device) + 'CL_PROGRAM_BUILD_LOG))))) + (define (binary-expr rator rand1 rand2) (string-append "(" rand1 " " rator " " rand2 ")")) @@ -146,7 +160,13 @@ EOF 1 (string->bytes/utf-8 kernel-code)))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) + (clBuildProgram program (vector (device)) (make-bytes 0) + ;; This extra argument works only because Darshal + ;; uses a modified version of the opencl/c library + ;; which makes the clBuildProgram function accept an + ;; additional callback argument for debugging just + ;; like the original C API. + print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf0) (clSetKernelArg:_cl_int kernel 1 stride0) @@ -175,7 +195,7 @@ EOF (λ (f-acc base-shape out-shape) (unless (and (null? base-shape) (null? out-shape)) (error 'ρ1-functional-non-scalar-acc - (string-append "Functional primitives can only accept and" + (string-append "Accelerated functional primitives can only accept and" " return scalars, so try defining a" " preallocated primitive instead." " Input and output shape found: ~a ~a") @@ -201,7 +221,7 @@ __kernel void Kernel (__global float* g0, int i0 = 0 + (iz / stridez) * stride0; @{(prim1-∇-f "g0" "v0" "i0" "stride0" - "vz" "iz" "stride-z")} + "vz" "iz" "stridez")} } EOF ) @@ -233,13 +253,13 @@ EOF (* (ctype-sizeof _cl_float) size0) #f)) - (set! program (clCreateProgramWithSource (context) (make-vector 1 (string->bytes/utf-8 kernel-code)))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) + (clBuildProgram program (vector (device)) (make-bytes 0) + print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf-g) (clSetKernelArg:_cl_mem kernel 1 buf0) @@ -271,7 +291,7 @@ EOF (λ (f-acc base-shape out-shape) (unless (and (null? base-shape) (null? out-shape)) (error '∇1-functional-non-scalar-acc - (string-append "Functional primitives can only accept and" + (string-append "Accelerated functional primitives can only accept and" " return scalars, so try defining a" " preallocated primitive instead." " Input and output shape found: ~a ~a") @@ -340,7 +360,8 @@ EOF (make-vector 1 (string->bytes/utf-8 kernel-code)))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) + (clBuildProgram program (vector (device)) (make-bytes 0) + print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf0) (clSetKernelArg:_cl_int kernel 1 stride0) @@ -373,7 +394,7 @@ EOF (λ (f-acc t-shape u-shape out-shape) (unless (and (null? t-shape) (null? u-shape) (null? out-shape)) (error 'ρ2-functional-non-scalar-acc - (string-append "Functional primitives can only accept and" + (string-append "Accelerated functional primitives can only accept and" " return scalars, so try defining a" " preallocated primitive instead." " Input 1, input 2 and output shape found: ~a ~a ~a") @@ -386,98 +407,70 @@ EOF EOF )))) -(define (ext2-∇-kernel-atomic prim2-∇-f strides) +(define (ext2-∇-kernel prim2-∇-f strides + s0 s1 r0 r1 s-out r-out) (let*-values (((prim-effect0 prim-effect1) (prim2-∇-f "g" "v0" "i0" "stride0" "v1" "i1" "stride1" "vz" "iz" "stride_z")) + ((repeats0 repeats1) (calc-repeats s0 s1 r0 r1 s-out r-out)) ((generate-idxs) (idx-exprs strides 0 0)) - ((i0-expr i1-expr) (generate-idxs "iz"))) + ((generate-idxs-inv) (idx-exprs-inv strides 0 + repeats0 repeats1 s-out)) + ((i0-expr i1-expr) (generate-idxs "iz")) + ((iz-expr0 iz-expr1) (generate-idxs-inv "i0" "i1" "i_rep"))) #<bytes/utf-8 kernel-code)))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) + (clBuildProgram program (vector (device)) (make-bytes 0) + print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf-g0) (clSetKernelArg:_cl_mem kernel 1 buf-g1) (clSetKernelArg:_cl_mem kernel 2 buf0) (clSetKernelArg:_cl_int kernel 3 stride0) - (clSetKernelArg:_cl_mem kernel 4 buf1) - (clSetKernelArg:_cl_int kernel 5 stride1) - (clSetKernelArg:_cl_mem kernel 6 buf-z) - (clSetKernelArg:_cl_int kernel 7 stride-z)) + (clSetKernelArg:_cl_int kernel 4 size0) + (clSetKernelArg:_cl_mem kernel 5 buf1) + (clSetKernelArg:_cl_int kernel 6 stride1) + (clSetKernelArg:_cl_int kernel 7 size1) + (clSetKernelArg:_cl_mem kernel 8 buf-z) + (clSetKernelArg:_cl_int kernel 9 stride-z)) (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-z stride-z)) + (make-vector 1 global-work-size) (make-vector 0) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 @@ -552,84 +548,11 @@ EOF (when buf0 (clReleaseMemObject buf0)))))))) -(define (run-prim2-∇-split! kernel-code0 kernel-code1 g0 g1 - v0 off0 size0 stride0 - v1 off1 size1 stride1 - vz offz size-z stride-z) - (with-opencl - (λ () - (define (run! kernel-code g size-in stride-in) - (let* ([buf0 #f] - [buf1 #f] - [buf-z #f] - [buf-g #f] - [program #f] - [kernel #f] - [event #f]) - (dynamic-wind - (λ () - ;; Exclude memory consumed by elements before offset of input vector v0 - (set! buf0 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size0) - (vref-cpointer v0 off0))) - (set! buf1 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size1) - (vref-cpointer v1 off1))) - (set! buf-z (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size-z) - (vref-cpointer vz offz))) - (set! buf-g (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY - (* (ctype-sizeof _cl_float) - size-in) - #f)) - (set! program (clCreateProgramWithSource - (context) - (make-vector 1 (string->bytes/utf-8 kernel-code)))) - (clBuildProgram program (make-vector 0) (make-bytes 0)) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf-g) - (clSetKernelArg:_cl_mem kernel 1 buf0) - (clSetKernelArg:_cl_int kernel 2 stride0) - (clSetKernelArg:_cl_mem kernel 3 buf1) - (clSetKernelArg:_cl_int kernel 4 stride1) - (clSetKernelArg:_cl_mem kernel 5 buf-z) - (clSetKernelArg:_cl_int kernel 6 stride-z)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-in stride-in)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size-in) - (vec->cpointer g) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-g - (clReleaseMemObject buf-g)) - (when buf-z - (clReleaseMemObject buf-z)) - (when buf1 - (clReleaseMemObject buf1)) - (when buf0 - (clReleaseMemObject buf0)))))) - (run! kernel-code0 g0 size0 stride0) - (run! kernel-code1 g1 size1 stride1)))) - (define functional->preallocated-2-∇-acc (λ (f-acc t-shape u-shape out-shape) (unless (and (null? t-shape) (null? u-shape) (null? out-shape)) (error '∇2-functional-non-scalar-acc - (string-append "Functional primitives can only accept and" + (string-append "Accelerated functional primitives can only accept and" " return scalars, so try defining a" " preallocated primitive instead." " Input 1, input 2 and output shape found: ~a ~a ~a") @@ -652,6 +575,4 @@ EOF (provide run-prim1-ρ! functional->preallocated-1-ρ-acc ext1-ρ-kernel run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel - run-prim2-∇-atomic! run-prim2-∇-split! - functional->preallocated-2-∇-acc - ext2-∇-kernel-atomic ext2-∇-kernel-split) + run-prim2-∇! functional->preallocated-2-∇-acc ext2-∇-kernel) diff --git a/accelerated-tensors/tensors/A-equality.rkt b/accelerated-tensors/tensors/A-equality.rkt index 639533f..ae320c6 100644 --- a/accelerated-tensors/tensors/A-equality.rkt +++ b/accelerated-tensors/tensors/A-equality.rkt @@ -59,7 +59,12 @@ (vref expected-store i-expected)) check) (else (return #f))))))))) -(define-binary-check (check-tensor-equal? tensor-equal? actual expected)) +(define-check (check-tensor-equal? actual expected) + (unless (tensor-equal? actual expected) + (fail-check (format "Tensors failed to match.~%actual:~%~s~%expected:~s~%~%actual store:~%~s~%expected store:~s~%" + actual expected + (with-output-to-string (λ () (print-vec (flat-store actual)))) + (with-output-to-string (λ () (print-vec (flat-store expected)))))))) (include "test/test-A-equality.rkt") diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index 3421a75..4613e30 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -20,7 +20,7 @@ (λ (t) (cond ((number? t) (f t)) - ((expects-preallocated? f) + ((expects-preallocated? f-acc) (scalarize (flat-ext1-ρ f f-acc m shape-fn t))) (else @@ -37,7 +37,7 @@ (λ (t z) (cond ((number? t) (f t z)) - ((expects-preallocated? f) + ((expects-preallocated? f-acc) (scalarize (flat-ext1-∇ f f-acc m shape-fn t (ensure-flat z)))) (else (let* ((in-shape (flat-shape t)) @@ -67,24 +67,14 @@ (cond ((null? out-shape) (vset! v-out i-out a)) (else - (error 'ρ-functional-non-scalar-out - (string-append "Functional primitives can only return scalars," - " so try defining a preallocated primitive" - " instead. Out shape found: ~a") - out-shape) - #;(v-copy-flat! v-out i-out a))))) + (v-copy-flat! v-out i-out a))))) (define set-prealloc-∇! (λ (v-out i-out out-shape a) (cond ((null? out-shape) (vset! v-out i-out (+ (vref v-out i-out) a))) (else - (error '∇-functional-non-scalar-out - (string-append "Functional primitives can only return scalars," - " so try defining a preallocated primitive" - " instead. Out shape found: ~a") - out-shape) - #;(v-add-flat! v-out i-out a))))) + (v-add-flat! v-out i-out a))))) (define arg-value (λ (v-shape v i) @@ -114,7 +104,7 @@ (λ (t u) (cond ((and (number? t) (number? u)) (f t u)) - ((expects-preallocated? f) + ((expects-preallocated? f-acc) (scalarize (flat-ext2-ρ f f-acc m n shape-fn t u))) ((number? t) @@ -151,7 +141,7 @@ (values (scalarize da) (scalarize db)))))) (cond ((and (number? t) (number? u)) (f t u z)) - ((expects-preallocated? f) + ((expects-preallocated? f-acc) (invoke-flat-ext2-∇ f f-acc m n shape-fn t u z)) ((number? t) (let* ((t-shape '()) @@ -320,20 +310,13 @@ (g1 (new-vec (size-of s1) 0.0))) (cond ((accelerate?) - (cond - (parallel-desc? (run-prim2-∇-atomic! (ext2-∇-kernel-atomic fᵈ-acc strides) - g0 g1 - v0 off0 size0 stride0 - v1 off1 size1 stride1 - vz offz size-z stride-z)) - (else - (let*-values (((kernel-code0 kernel-code1) - (ext2-∇-kernel-split fᵈ-acc strides s0 s1 r0 r1 sz (length sf-z)))) - (run-prim2-∇-split! kernel-code0 kernel-code1 - g0 g1 - v0 off0 size0 stride0 - v1 off1 size1 stride1 - vz offz size-z stride-z))))) + (let ((kernel-code (ext2-∇-kernel fᵈ-acc strides s0 s1 r0 r1 sz + (length sf-z)))) + (run-prim2-∇! kernel-code + g0 g1 + v0 off0 size0 stride0 + v1 off1 size1 stride1 + vz offz size-z stride-z))) (else (for ([iz (in-range 0 size-z stride-z)]) (let-values (((i0 i1) (idxs strides iz off0 off1))) @@ -462,5 +445,7 @@ (provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? functional->preallocated-1-ρ functional->preallocated-1-∇ functional->preallocated-2-ρ functional->preallocated-2-∇ + functional->preallocated-1-ρ-acc functional->preallocated-1-∇-acc + functional->preallocated-2-ρ-acc functional->preallocated-2-∇-acc merge-shapes min-shape ext2-shapes idxs flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/impl-loader.rkt b/impl-loader.rkt index 7c3dc7d..c2f5125 100644 --- a/impl-loader.rkt +++ b/impl-loader.rkt @@ -36,14 +36,19 @@ (λ () (car (dict-ref (settings) 'accelerate?)))) +(define debug-kernel? + (λ () + (car (dict-ref (settings) 'debug-kernel?)))) + ;; Default settings (define default-preferences `((tensor-implementation learner) - (accelerate? #t))) + (accelerate? #t) + (debug-kernel? #f))) (when (not (settings)) (init-settings) (println "settings=") (pretty-print (settings))) -(provide tensor-implementation accelerate?) +(provide tensor-implementation accelerate? debug-kernel?) diff --git a/malted/test/test-A-core.rkt b/malted/test/test-A-core.rkt index 22aad31..596ebde 100644 --- a/malted/test/test-A-core.rkt +++ b/malted/test/test-A-core.rkt @@ -1,5 +1,6 @@ (module+ test (require rackunit) + (require "../base.rkt") (let ((a 7) (b 13)) @@ -39,9 +40,11 @@ (tensor -0.04142 -0.03111)))) (let ((a (tensor 7 8 9))) - (check-dual-equal? (exp a) (tensor 1096.6332 2980.9579 8103.0839)) - (check-dual-equal? ((∇¹ exp) a) - (list (tensor 1096.6332 2980.9579 8103.0839))) + ;; Lower tolerance because the openCL implementation of exp gives a slightly different answer + (parameterize ((tolerance 0.001)) + (check-dual-equal? (exp a) (tensor 1096.6332 2980.9579 8103.0839)) + (check-dual-equal? ((∇¹ exp) a) + (list (tensor 1096.6332 2980.9579 8103.0839)))) (check-dual-equal? (log a) (tensor 1.9459 2.0794 2.1972)) (check-dual-equal? ((∇¹ log) a) (list (tensor 0.1428 0.125 0.1111))) (check-dual-equal? (sqrt a) (tensor 2.6457 2.8284 3.0)) From 6b57ad861e1891352ed9c6b1af7fc3b6ab2baed7 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 20 Apr 2024 11:09:27 -0400 Subject: [PATCH 39/83] [add-acc]Debug multiple opencl context references --- accelerated-tensors/tensors/2-acc-runtime.rkt | 76 ++++++++++++++----- .../tensors/test/test-2-acc-runtime.rkt | 9 +++ 2 files changed, 68 insertions(+), 17 deletions(-) create mode 100644 accelerated-tensors/tensors/test/test-2-acc-runtime.rkt diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index d3b38ab..0e9b69a 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -10,29 +10,61 @@ (define context (make-parameter #f)) (define command-queue (make-parameter #f)) -(define device (make-parameter #f)) +(define platform + (let ([platform #f]) + (lambda () + (or platform + (begin + (set! platform (cvector-ref (clGetPlatformIDs:vector) 0)) + platform))))) +(define devices + (let ([devices #f]) + (lambda () + (or devices + (begin + (set! devices (clGetDeviceIDs:vector (platform) 'CL_DEVICE_TYPE_GPU)) + devices))))) +(define device + (let ([device #f]) + (lambda () + (or device + (begin + (set! device (cvector-ref (devices) 0)) + device))))) (define (cvector->vector cv) (build-vector (cvector-length cv) (curry cvector-ref cv))) (define (with-opencl th) - (let* ([platform (cvector-ref (clGetPlatformIDs:vector) 0)] - [devices (clGetDeviceIDs:vector platform 'CL_DEVICE_TYPE_GPU)] - [device-idx 0]) - (parameterize* ([context #f] - [command-queue #f] - [device (cvector-ref devices device-idx)]) - (dynamic-wind - (λ () - (context (clCreateContext #f (cvector->vector devices))) - (command-queue (clCreateCommandQueue (context) (device) '()))) - th - (λ () - (when (command-queue) - (clReleaseCommandQueue (command-queue))) - (when (context) - (clReleaseContext (context)))))))) + (dynamic-wind + (λ () + (unless (context) + (context (clCreateContext #f (cvector->vector (devices)))) + (when (debug-kernel?) + (printf "Context reference count after creation: ~a~n" + (clGetContextInfo:generic (context) 'CL_CONTEXT_REFERENCE_COUNT)))) + (unless (command-queue) + (command-queue (clCreateCommandQueue (context) (device) '())) + (when (debug-kernel?) + (printf "CommandQueue reference count after creation: ~a~n" + (clGetCommandQueueInfo:generic (command-queue) + 'CL_QUEUE_REFERENCE_COUNT))))) + th + (λ () + (when (command-queue) + (when (debug-kernel?) + (printf "CommandQueue reference count before release: ~a~n" + (clGetCommandQueueInfo:generic (command-queue) + 'CL_QUEUE_REFERENCE_COUNT))) + (clReleaseCommandQueue (command-queue)) + (command-queue #f)) + (when (context) + (when (debug-kernel?) + (printf "Context reference count before release: ~a~n" + (clGetContextInfo:generic (context) 'CL_CONTEXT_REFERENCE_COUNT))) + (clReleaseContext (context)) + (context #f))))) (define print-cl-build-log (λ (program _) @@ -137,6 +169,8 @@ EOF (define (run-prim1-ρ! kernel-code v0 off0 size0 stride0 v-out size-out stride-out) + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () (let* ([buf0 #f] @@ -229,6 +263,8 @@ EOF (define (run-prim1-∇! kernel-code g0 v0 off0 size0 stride0 vz offz size-z stride-z) + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () (let* ([buf0 #f] @@ -331,6 +367,8 @@ EOF v0 off0 size0 stride0 v1 off1 size1 stride1 v-out size-out stride-out) + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () (let* ([buf0 #f] @@ -466,6 +504,8 @@ EOF v0 off0 size0 stride0 v1 off1 size1 stride1 vz offz size-z stride-z) + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () (let* ([global-work-size (max (/ size0 stride0) @@ -572,6 +612,8 @@ EOF EOF )))))) +(include "test/test-2-acc-runtime.rkt") + (provide run-prim1-ρ! functional->preallocated-1-ρ-acc ext1-ρ-kernel run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel diff --git a/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt b/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt new file mode 100644 index 0000000..e15c672 --- /dev/null +++ b/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt @@ -0,0 +1,9 @@ +(module+ test + (require rackunit) + + (for ((_ (in-range 100))) + (with-opencl + (λ () + (check-true (not (not (context)))) + (check-true (not (not (command-queue))))))) + ) From f3f2a1de2ec040f6f0ba617c80936c69406122ab Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 20 Apr 2024 11:10:00 -0400 Subject: [PATCH 40/83] =?UTF-8?q?[add-acc]Fix=20F-max=20=E2=88=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- accelerated-tensors/ext-ops/F-max.rkt | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/accelerated-tensors/ext-ops/F-max.rkt b/accelerated-tensors/ext-ops/F-max.rkt index 78c6529..d78b904 100644 --- a/accelerated-tensors/ext-ops/F-max.rkt +++ b/accelerated-tensors/ext-ops/F-max.rkt @@ -35,9 +35,8 @@ EOF (for/fold ([max -inf.0] [max-i -1] #:result (for ([i (in-range i0 (+ i0 stride0))]) - (cond - ((= i (+ i0 max-i)) (vset! g0 i z)) - (else (vset! g0 i 0.0))))) + (when (= i (+ i0 max-i)) + (vset! g0 i (+ (vref g0 i) z))))) ([i (in-range i0 (+ i0 stride0))]) (let ((v (vref v0 i))) (cond @@ -60,9 +59,9 @@ EOF } for(int i=@{i0}; i<@{i0}+@{stride0}; i++) { if(i == @{i0}+max_i) { - @{g0}[i] = z; + @{g0}[i] += z; } else { - @{g0}[i] = 0.0; + @{g0}[i] += 0.0; } } EOF From 2ae5f4b6fe625c3168a8362a950f4a2757560af9 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sun, 28 Apr 2024 10:24:22 -0400 Subject: [PATCH 41/83] [add-acc]add debug code for run* and A-scalar-ops --- .../ext-ops/test/test-A-scalar-ops.rkt | 78 +++++++++++-------- accelerated-tensors/tensors/2-acc-runtime.rkt | 30 +++---- 2 files changed, 61 insertions(+), 47 deletions(-) diff --git a/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt b/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt index 2c13e39..c842693 100644 --- a/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt +++ b/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt @@ -2,7 +2,10 @@ (require rackunit) (require (only-in "../tensors.rkt" tensor)) + (for ((i (in-range 100))) + (printf "### Iteration ~a~n." i) ;; Check basic numericals + #; (let ((a 2) (b 3)) (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) @@ -38,41 +41,45 @@ (let ((a (tensor 2.0 3.0 4.0)) (b (tensor 3.0 8.0 9.0))) - (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) - (check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) - (check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) - (check-ρ-∇ (d/ a b) - (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) - (list (tensor 0.3333333333333333 0.125 0.1111111111111111) - (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) - (check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) - (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) - (check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) - (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) - (check-ρ-∇ (d-expt a b) - (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) - (list (tensor 12.0 17496.0 589824.0) - (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) - - (check-ρ-∇ (d-sqrt a) - (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) - (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) - - (check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) - - (define z - (lambda (x y) - (d+ (d-log x) (d* x y)))) - - (let ((c (dual* (tensor 2.0 3.0 4.0))) - (d (dual* (tensor 2.0 3.0 4.0)))) - (check-dual-equal? - ((∇¹ z) c d) - (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) - - (check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + (check-dual-equal? (d+ a b) (tensor 5.0 11.0 13.0)) + ;(check-dual-equal? ((∇¹ d+) a b) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + + ;(check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + ;(check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) + ;(check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) + ;(check-ρ-∇ (d/ a b) + ; (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) + ; (list (tensor 0.3333333333333333 0.125 0.1111111111111111) + ; (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) + ;(check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) + ; (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) + ;(check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) + ; (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) + ;(check-ρ-∇ (d-expt a b) + ; (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) + ; (list (tensor 12.0 17496.0 589824.0) + ; (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) + + ;(check-ρ-∇ (d-sqrt a) + ; (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) + ; (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) + + ;(check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) + + ;(define z + ; (lambda (x y) + ; (d+ (d-log x) (d* x y)))) + + ;(let ((c (dual* (tensor 2.0 3.0 4.0))) + ; (d (dual* (tensor 2.0 3.0 4.0)))) + ; (check-dual-equal? + ; ((∇¹ z) c d) + ; (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + #;(check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) ;; Check numericals with lists + #; (let ((x (dual* 3)) (y (dual* 2))) (let ((f d*)) @@ -108,6 +115,7 @@ x y) '(4.0 6.0)))) + #| (let ((a 7) (b (tensor 13))) (check-ρ-∇ (d+ a b) (tensor 20) (list 1.0 (tensor 1.0))) @@ -119,4 +127,6 @@ (check-ρ-∇ (d+ a b) (tensor 20 22) (list 2.0 (tensor 1.0 1.0))) (check-ρ-∇ (d* a b) (tensor 91 105) (list 28.0 (tensor 7.0 7.0))) (check-ρ-∇ (d/ a b) (tensor 7/13 7/15) - (list 0.14358 (tensor -0.04142 -0.03111))))) + (list 0.14358 (tensor -0.04142 -0.03111)))) +|# + )) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index 0e9b69a..1a1c244 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -36,9 +36,12 @@ (build-vector (cvector-length cv) (curry cvector-ref cv))) +(define in-opencl (make-parameter 0)) (define (with-opencl th) (dynamic-wind (λ () + (in-opencl (add1 (in-opencl))) + (printf "###Nesting level: ~a~n" (in-opencl)) (unless (context) (context (clCreateContext #f (cvector->vector (devices)))) (when (debug-kernel?) @@ -64,7 +67,8 @@ (printf "Context reference count before release: ~a~n" (clGetContextInfo:generic (context) 'CL_CONTEXT_REFERENCE_COUNT))) (clReleaseContext (context)) - (context #f))))) + (context #f)) + (in-opencl (sub1 (in-opencl)))))) (define print-cl-build-log (λ (program _) @@ -169,10 +173,10 @@ EOF (define (run-prim1-ρ! kernel-code v0 off0 size0 stride0 v-out size-out stride-out) - (when (debug-kernel?) - (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (let* ([buf0 #f] [buf-out #f] [program #f] @@ -200,7 +204,7 @@ EOF ;; which makes the clBuildProgram function accept an ;; additional callback argument for debugging just ;; like the original C API. - print-cl-build-log) + #;print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf0) (clSetKernelArg:_cl_int kernel 1 stride0) @@ -263,10 +267,10 @@ EOF (define (run-prim1-∇! kernel-code g0 v0 off0 size0 stride0 vz offz size-z stride-z) - (when (debug-kernel?) - (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (let* ([buf0 #f] [buf-z #f] [buf-g #f] @@ -295,7 +299,7 @@ EOF (string->bytes/utf-8 kernel-code)))) (clBuildProgram program (vector (device)) (make-bytes 0) - print-cl-build-log) + #;print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf-g) (clSetKernelArg:_cl_mem kernel 1 buf0) @@ -367,10 +371,10 @@ EOF v0 off0 size0 stride0 v1 off1 size1 stride1 v-out size-out stride-out) - (when (debug-kernel?) - (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (let* ([buf0 #f] [buf1 #f] [buf-out #f] @@ -399,7 +403,7 @@ EOF 1 (string->bytes/utf-8 kernel-code)))) (clBuildProgram program (vector (device)) (make-bytes 0) - print-cl-build-log) + #;print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf0) (clSetKernelArg:_cl_int kernel 1 stride0) @@ -504,10 +508,10 @@ EOF v0 off0 size0 stride0 v1 off1 size1 stride1 vz offz size-z stride-z) - (when (debug-kernel?) - (printf "Kernel Code:~n~a~n" kernel-code)) (with-opencl (λ () + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) (let* ([global-work-size (max (/ size0 stride0) (/ size1 stride1))] [buf0 #f] @@ -547,7 +551,7 @@ EOF (context) (make-vector 1 (string->bytes/utf-8 kernel-code)))) (clBuildProgram program (vector (device)) (make-bytes 0) - print-cl-build-log) + #;print-cl-build-log) (set! kernel (clCreateKernel program #"Kernel")) (clSetKernelArg:_cl_mem kernel 0 buf-g0) (clSetKernelArg:_cl_mem kernel 1 buf-g1) From 4fd229098f37ee93a6652a21a467a9614d411081 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 22 May 2024 15:34:53 -0400 Subject: [PATCH 42/83] [add-acc]Fix gpu-related intermittent bugs and cleanup code --- .../ext-ops/test/test-A-scalar-ops.rkt | 78 ++-- accelerated-tensors/tensors/0-vectors.rkt | 4 +- accelerated-tensors/tensors/2-acc-runtime.rkt | 403 +++++++++--------- accelerated-tensors/tensors/D-extend.rkt | 14 +- .../tensors/test/test-2-acc-runtime.rkt | 7 +- impl-loader.rkt | 4 +- 6 files changed, 244 insertions(+), 266 deletions(-) diff --git a/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt b/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt index c842693..2c13e39 100644 --- a/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt +++ b/accelerated-tensors/ext-ops/test/test-A-scalar-ops.rkt @@ -2,10 +2,7 @@ (require rackunit) (require (only-in "../tensors.rkt" tensor)) - (for ((i (in-range 100))) - (printf "### Iteration ~a~n." i) ;; Check basic numericals - #; (let ((a 2) (b 3)) (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) @@ -41,45 +38,41 @@ (let ((a (tensor 2.0 3.0 4.0)) (b (tensor 3.0 8.0 9.0))) - (check-dual-equal? (d+ a b) (tensor 5.0 11.0 13.0)) - ;(check-dual-equal? ((∇¹ d+) a b) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) - - ;(check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) - ;(check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) - ;(check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) - ;(check-ρ-∇ (d/ a b) - ; (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) - ; (list (tensor 0.3333333333333333 0.125 0.1111111111111111) - ; (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) - ;(check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) - ; (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) - ;(check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) - ; (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) - ;(check-ρ-∇ (d-expt a b) - ; (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) - ; (list (tensor 12.0 17496.0 589824.0) - ; (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) - - ;(check-ρ-∇ (d-sqrt a) - ; (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) - ; (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) - - ;(check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) - - ;(define z - ; (lambda (x y) - ; (d+ (d-log x) (d* x y)))) - - ;(let ((c (dual* (tensor 2.0 3.0 4.0))) - ; (d (dual* (tensor 2.0 3.0 4.0)))) - ; (check-dual-equal? - ; ((∇¹ z) c d) - ; (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) - - #;(check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) + (check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) + (check-ρ-∇ (d/ a b) + (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) + (list (tensor 0.3333333333333333 0.125 0.1111111111111111) + (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) + (check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) + (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) + (check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) + (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) + (check-ρ-∇ (d-expt a b) + (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) + (list (tensor 12.0 17496.0 589824.0) + (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) + + (check-ρ-∇ (d-sqrt a) + (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) + (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) + + (check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* (tensor 2.0 3.0 4.0))) + (d (dual* (tensor 2.0 3.0 4.0)))) + (check-dual-equal? + ((∇¹ z) c d) + (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + (check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) ;; Check numericals with lists - #; (let ((x (dual* 3)) (y (dual* 2))) (let ((f d*)) @@ -115,7 +108,6 @@ x y) '(4.0 6.0)))) - #| (let ((a 7) (b (tensor 13))) (check-ρ-∇ (d+ a b) (tensor 20) (list 1.0 (tensor 1.0))) @@ -127,6 +119,4 @@ (check-ρ-∇ (d+ a b) (tensor 20 22) (list 2.0 (tensor 1.0 1.0))) (check-ρ-∇ (d* a b) (tensor 91 105) (list 28.0 (tensor 7.0 7.0))) (check-ρ-∇ (d/ a b) (tensor 7/13 7/15) - (list 0.14358 (tensor -0.04142 -0.03111)))) -|# - )) + (list 0.14358 (tensor -0.04142 -0.03111))))) diff --git a/accelerated-tensors/tensors/0-vectors.rkt b/accelerated-tensors/tensors/0-vectors.rkt index 8ef4d37..8135743 100644 --- a/accelerated-tensors/tensors/0-vectors.rkt +++ b/accelerated-tensors/tensors/0-vectors.rkt @@ -19,10 +19,10 @@ (define vec->cpointer f32vector->cpointer) (define vref-cpointer (λ (v i) - (unless (< i (vlen v)) + (unless (and (<= 0 i) (< i (vlen v))) (error 'vref-cpointer "Index ~a out of range [0, ~a]" - i (vlen v))) + i (sub1 (vlen v)))) (ptr-add (vec->cpointer v) i _float))) (define-for-syntax debug-leaks? #f) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index 1a1c244..a94bc5f 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -7,9 +7,32 @@ "0-vectors.rkt" "../../impl-loader.rkt") +;; TODO: Cache compiled kernels based on a unique prim name + +(define context + (let ([context #f]) + (λ () + (or context + (begin + (set! context (clCreateContext #f (cvector->vector (devices)))) + context))))) +(define command-queue + (let ([command-queue #f]) + (λ () + (or command-queue + (begin + (set! command-queue (clCreateCommandQueue (context) (device) '())) + command-queue))))) + +(define old-exit-handler (exit-handler)) +(exit-handler + (λ (v) + (when (command-queue) + (clReleaseCommandQueue (command-queue))) + (when (context) + (clReleaseContext (context))) + (old-exit-handler v))) -(define context (make-parameter #f)) -(define command-queue (make-parameter #f)) (define platform (let ([platform #f]) (lambda () @@ -36,40 +59,8 @@ (build-vector (cvector-length cv) (curry cvector-ref cv))) -(define in-opencl (make-parameter 0)) -(define (with-opencl th) - (dynamic-wind - (λ () - (in-opencl (add1 (in-opencl))) - (printf "###Nesting level: ~a~n" (in-opencl)) - (unless (context) - (context (clCreateContext #f (cvector->vector (devices)))) - (when (debug-kernel?) - (printf "Context reference count after creation: ~a~n" - (clGetContextInfo:generic (context) 'CL_CONTEXT_REFERENCE_COUNT)))) - (unless (command-queue) - (command-queue (clCreateCommandQueue (context) (device) '())) - (when (debug-kernel?) - (printf "CommandQueue reference count after creation: ~a~n" - (clGetCommandQueueInfo:generic (command-queue) - 'CL_QUEUE_REFERENCE_COUNT))))) - th - (λ () - (when (command-queue) - (when (debug-kernel?) - (printf "CommandQueue reference count before release: ~a~n" - (clGetCommandQueueInfo:generic (command-queue) - 'CL_QUEUE_REFERENCE_COUNT))) - (clReleaseCommandQueue (command-queue)) - (command-queue #f)) - (when (context) - (when (debug-kernel?) - (printf "Context reference count before release: ~a~n" - (clGetContextInfo:generic (context) 'CL_CONTEXT_REFERENCE_COUNT))) - (clReleaseContext (context)) - (context #f)) - (in-opencl (sub1 (in-opencl)))))) - +;; callback function to be used for debugging clBuildProgram if we expose that +;; parameter in the opencl/c library source code. (define print-cl-build-log (λ (program _) (when (debug-kernel?) @@ -152,9 +143,14 @@ (values (/ size-rep-out size-rep0) (/ size-rep-out size-rep1)))) +(define kernel-name + (lambda (fn) + "kernel_@{(~a (eq-hash-code fn))}")) + (define (ext1-ρ-kernel prim1-ρ-f) #<bytes/utf-8 - kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0) - ;; This extra argument works only because Darshal - ;; uses a modified version of the opencl/c library - ;; which makes the clBuildProgram function accept an - ;; additional callback argument for debugging just - ;; like the original C API. - #;print-cl-build-log) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf0) - (clSetKernelArg:_cl_int kernel 1 stride0) - (clSetKernelArg:_cl_mem kernel 2 buf-out) - (clSetKernelArg:_cl_int kernel 3 stride-out)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-out stride-out)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size-out) - (vec->cpointer v-out) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-out - (clReleaseMemObject buf-out)) - (when buf0 - (clReleaseMemObject buf0)))))))) + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) + (let* ([buf0 #f] + [buf-out #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf-out (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size-out) + #f)) + (set! program (clCreateProgramWithSource (context) + (make-vector + 1 + (string->bytes/utf-8 + kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf-out) + (clSetKernelArg:_cl_int kernel 3 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf0 + (clReleaseMemObject buf0)))))) (define functional->preallocated-1-ρ-acc (λ (f-acc base-shape out-shape) @@ -247,7 +236,8 @@ EOF (define (ext1-∇-kernel prim1-∇-f) #<bytes/utf-8 - kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0) - #;print-cl-build-log) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf-g) - (clSetKernelArg:_cl_mem kernel 1 buf0) - (clSetKernelArg:_cl_int kernel 2 stride0) - (clSetKernelArg:_cl_mem kernel 3 buf-z) - (clSetKernelArg:_cl_int kernel 4 stride-z)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-z stride-z)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size0) - (vec->cpointer g0) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-g - (clReleaseMemObject buf-g)) - (when buf-z - (clReleaseMemObject buf-z)) - (when buf0 - (clReleaseMemObject buf0)))))))) + (when (debug-kernel?) + (printf "Kernel Code:~n~a~n" kernel-code)) + (let* ([buf0 #f] + [buf-z #f] + [buf-g #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf-z (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size-z) + (vref-cpointer vz offz))) + (set! buf-g (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size0) + #f)) + (set! program (clCreateProgramWithSource (context) + (make-vector + 1 + (string->bytes/utf-8 + kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf-g) + (clSetKernelArg:_cl_mem kernel 1 buf0) + (clSetKernelArg:_cl_int kernel 2 stride0) + (clSetKernelArg:_cl_mem kernel 3 buf-z) + (clSetKernelArg:_cl_int kernel 4 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-z stride-z)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size0) + (vec->cpointer g0) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g + (clReleaseMemObject buf-g)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf0 + (clReleaseMemObject buf0)))) + ;)) + )) (define functional->preallocated-1-∇-acc (λ (f-acc base-shape out-shape) @@ -348,7 +338,8 @@ EOF (let*-values (((generate-idxs) (idx-exprs strides 0 0)) ((i0-expr i1-expr) (generate-idxs "i_out"))) #<bytes/utf-8 kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0) - #;print-cl-build-log) - (set! kernel (clCreateKernel program #"Kernel")) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) (clSetKernelArg:_cl_mem kernel 0 buf0) (clSetKernelArg:_cl_int kernel 1 stride0) (clSetKernelArg:_cl_mem kernel 2 buf1) @@ -430,7 +419,7 @@ EOF (when buf1 (clReleaseMemObject buf1)) (when buf0 - (clReleaseMemObject buf0)))))))) + (clReleaseMemObject buf0)))))) (define functional->preallocated-2-ρ-acc (λ (f-acc t-shape u-shape out-shape) @@ -462,7 +451,8 @@ EOF ((i0-expr i1-expr) (generate-idxs "iz")) ((iz-expr0 iz-expr1) (generate-idxs-inv "i0" "i1" "i_rep"))) #<bytes/utf-8 kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0) - #;print-cl-build-log) - (set! kernel (clCreateKernel program #"Kernel")) - (clSetKernelArg:_cl_mem kernel 0 buf-g0) - (clSetKernelArg:_cl_mem kernel 1 buf-g1) - (clSetKernelArg:_cl_mem kernel 2 buf0) - (clSetKernelArg:_cl_int kernel 3 stride0) - (clSetKernelArg:_cl_int kernel 4 size0) - (clSetKernelArg:_cl_mem kernel 5 buf1) - (clSetKernelArg:_cl_int kernel 6 stride1) - (clSetKernelArg:_cl_int kernel 7 size1) - (clSetKernelArg:_cl_mem kernel 8 buf-z) - (clSetKernelArg:_cl_int kernel 9 stride-z)) + (vref-cpointer v1 off1))) + (set! buf-z (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size-z) + (vref-cpointer vz offz))) + (set! buf-g0 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size0) + #f)) + (set! buf-g1 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size1) + #f)) + (set! program (clCreateProgramWithSource + (context) + (make-vector 1 (string->bytes/utf-8 kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf-g0) + (clSetKernelArg:_cl_mem kernel 1 buf-g1) + (clSetKernelArg:_cl_mem kernel 2 buf0) + (clSetKernelArg:_cl_int kernel 3 stride0) + (clSetKernelArg:_cl_int kernel 4 size0) + (clSetKernelArg:_cl_mem kernel 5 buf1) + (clSetKernelArg:_cl_int kernel 6 stride1) + (clSetKernelArg:_cl_int kernel 7 size1) + (clSetKernelArg:_cl_mem kernel 8 buf-z) + (clSetKernelArg:_cl_int kernel 9 stride-z)) (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 global-work-size) @@ -590,7 +578,7 @@ EOF (when buf1 (clReleaseMemObject buf1)) (when buf0 - (clReleaseMemObject buf0)))))))) + (clReleaseMemObject buf0)))))) (define functional->preallocated-2-∇-acc (λ (f-acc t-shape u-shape out-shape) @@ -621,4 +609,5 @@ EOF (provide run-prim1-ρ! functional->preallocated-1-ρ-acc ext1-ρ-kernel run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel - run-prim2-∇! functional->preallocated-2-∇-acc ext2-∇-kernel) + run-prim2-∇! functional->preallocated-2-∇-acc ext2-∇-kernel + kernel-name) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index 4613e30..e5542bd 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -215,6 +215,7 @@ (v-out (new-vec size-out 0.0))) (cond ((accelerate?) (run-prim1-ρ! (ext1-ρ-kernel f-acc) + (kernel-name f-acc) v0 off0 size0 stride0 v-out size-out stride-out)) (else @@ -242,7 +243,7 @@ (g0 (new-vec size0 0.0))) (cond - ((accelerate?) (run-prim1-∇! (ext1-∇-kernel fᵈ-acc) g0 + ((accelerate?) (run-prim1-∇! (ext1-∇-kernel fᵈ-acc) (kernel-name fᵈ-acc) g0 v0 off0 size0 stride0 vz offz size-z stride-z)) (else @@ -273,10 +274,11 @@ (λ (s-out size-out q0 q1 strides parallel-desc?) (let ((out-v (new-vec size-out 0.0))) (cond - ((accelerate?) (run-prim2-ρ! (ext2-ρ-kernel f-acc strides) - v0 off0 size0 stride0 - v1 off1 size1 stride1 - out-v size-out stride-out)) + ((accelerate?) + (run-prim2-ρ! (ext2-ρ-kernel f-acc strides) (kernel-name f-acc) + v0 off0 size0 stride0 + v1 off1 size1 stride1 + out-v size-out stride-out)) (else (for ([out-i (in-range 0 size-out stride-out)]) (let-values (((i0 i1) @@ -312,7 +314,7 @@ ((accelerate?) (let ((kernel-code (ext2-∇-kernel fᵈ-acc strides s0 s1 r0 r1 sz (length sf-z)))) - (run-prim2-∇! kernel-code + (run-prim2-∇! kernel-code (kernel-name fᵈ-acc) g0 g1 v0 off0 size0 stride0 v1 off1 size1 stride1 diff --git a/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt b/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt index e15c672..6149676 100644 --- a/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/test/test-2-acc-runtime.rkt @@ -2,8 +2,7 @@ (require rackunit) (for ((_ (in-range 100))) - (with-opencl - (λ () - (check-true (not (not (context)))) - (check-true (not (not (command-queue))))))) + (λ () + (check-true (not (not (context)))) + (check-true (not (not (command-queue)))))) ) diff --git a/impl-loader.rkt b/impl-loader.rkt index c2f5125..8c1a871 100644 --- a/impl-loader.rkt +++ b/impl-loader.rkt @@ -47,8 +47,6 @@ (debug-kernel? #f))) (when (not (settings)) - (init-settings) - (println "settings=") - (pretty-print (settings))) + (init-settings)) (provide tensor-implementation accelerate? debug-kernel?) From fe4cc4cff8678cca0dba7ca9e1871915f0c3d880 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 25 May 2024 12:01:05 -0400 Subject: [PATCH 43/83] [add-acc]Fix kernel names (WIP) --- accelerated-tensors/autodiff/B-prims.rkt | 22 ++-- accelerated-tensors/tensors/2-acc-runtime.rkt | 100 ++++++++++++------ accelerated-tensors/tensors/D-extend.rkt | 67 ++++++------ .../tensors/test/test-D-extend.rkt | 15 ++- 4 files changed, 134 insertions(+), 70 deletions(-) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt index 416e254..b654ef3 100644 --- a/accelerated-tensors/autodiff/B-prims.rkt +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -19,6 +19,9 @@ (define shape-fn (λ (f) (f shape-fn))) +(define signature + (λ (f) (f signature))) + ;; For flat tensors, ρ-fn and ∇-fn ;; are of two types: functional and pre-allocated ;; When they are functional, they return values @@ -43,7 +46,8 @@ (define prim1 (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) - (∇-callable (ensure-∇-callable-1 ∇-fn shape))) + (∇-callable (ensure-∇-callable-1 ∇-fn shape)) + (prim-sign (symbol->string (gensym 'prim1)))) (λ (daf) (cond ((eq? daf ρ-function) ρ-fn) @@ -51,6 +55,7 @@ ((eq? daf ∇-function) ∇-fn) ((eq? daf ∇-acc-function) ∇-acc-fn) ((eq? daf shape-fn) shape) + ((eq? daf signature) prim-sign) (else (prim1-dual ρ-callable ∇-callable daf))))))) (define prim1-dual @@ -64,7 +69,8 @@ (define prim2 (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) - (∇-callable (ensure-∇-callable-2 ∇-fn shape))) + (∇-callable (ensure-∇-callable-2 ∇-fn shape)) + (prim-sign (symbol->string (gensym 'prim2)))) (λ ds (let ((daf (ref ds 0))) (cond @@ -73,6 +79,7 @@ ((eq? daf ∇-function) ∇-fn) ((eq? daf ∇-acc-function) ∇-acc-fn) ((eq? daf shape-fn) shape) + ((eq? daf signature) prim-sign) (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1))))))))) (define prim2-dual @@ -214,21 +221,24 @@ ;;---------------------------- ;; Dualized tensor op creators ;;---------------------------- + +;; TODO: Figure out the behaviour when we compose ext* with ext*. Currently we +;; assume that "f" is always a non-extended primitive. (define ext1 (λ (f n) (prim1 - (ext1-ρ (ρ-function f) (ρ-acc-function f) n (shape-fn f)) + (ext1-ρ (ρ-function f) (ρ-acc-function f) n (shape-fn f) (signature f)) (ρ-acc-function f) - (ext1-∇ (∇-function f) (∇-acc-function f) n (shape-fn f)) + (ext1-∇ (∇-function f) (∇-acc-function f) n (shape-fn f) (signature f)) (∇-acc-function f) (shape-fn f)))) (define ext2 (λ (f m n) (prim2 - (ext2-ρ (ρ-function f) (ρ-acc-function f) m n (shape-fn f)) + (ext2-ρ (ρ-function f) (ρ-acc-function f) m n (shape-fn f) (signature f)) (ρ-acc-function f) - (ext2-∇ (∇-function f) (∇-acc-function f) m n (shape-fn f)) + (ext2-∇ (∇-function f) (∇-acc-function f) m n (shape-fn f) (signature f)) (∇-acc-function f) (shape-fn f)))) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index a94bc5f..93eadcb 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -4,6 +4,7 @@ ffi/unsafe opencl/c string-interpolation + file/xxhash32 "0-vectors.rkt" "../../impl-loader.rkt") @@ -147,10 +148,10 @@ (lambda (fn) "kernel_@{(~a (eq-hash-code fn))}")) -(define (ext1-ρ-kernel prim1-ρ-f) - #<bytes/utf-8 s1)) + (xxh32-update! ctx #"_") + (xxh32-update! ctx (string->bytes/utf-8 s2)) + (xxh32-update! ctx #"_") + (xxh32-update! ctx (string->bytes/utf-8 s3)) + (xxh32-update! ctx #"#")))) + +(define (ext2-ρ-kernel-name prim-sign strides) + (define xxh32-ctx (make-xxh32)) + (xxh32-reset! xxh32-ctx 0) + (strides-signature! xxh32-ctx strides) + (define strides-hash (xxh32-digest xxh32-ctx)) + (format "~a_~a" prim-sign (~a strides-hash))) + +(define (ext2-ρ-kernel/name prim2-ρ-f prim-sign strides) (let*-values (((generate-idxs) (idx-exprs strides 0 0)) - ((i0-expr i1-expr) (generate-idxs "i_out"))) - #<bytes/utf-8 (~a s0))) + (xxh32-update! xxh32-ctx #"_") + (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a s1))) + (xxh32-update! xxh32-ctx #"_") + (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a r0))) + (xxh32-update! xxh32-ctx #"_") + (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a r1))) + (xxh32-update! xxh32-ctx #"_") + (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a s-out))) + (xxh32-update! xxh32-ctx #"_") + (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a r-out))) + (define params-hash (xxh32-digest xxh32-ctx)) + (format "~a_~a" prim-sign (~a params-hash))) + +(define (ext2-∇-kernel/name prim2-∇-f prim-sign strides + s0 s1 r0 r1 s-out r-out) (let*-values (((prim-effect0 prim-effect1) (prim2-∇-f "g" "v0" "i0" "stride0" "v1" "i1" "stride1" @@ -449,10 +484,12 @@ EOF ((generate-idxs-inv) (idx-exprs-inv strides 0 repeats0 repeats1 s-out)) ((i0-expr i1-expr) (generate-idxs "iz")) - ((iz-expr0 iz-expr1) (generate-idxs-inv "i0" "i1" "i_rep"))) - #<bytes/utf-8 kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0)) + (clBuildProgram program (vector (device)) (make-bytes 0) print-cl-build-log) (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) (clSetKernelArg:_cl_mem kernel 0 buf-g0) (clSetKernelArg:_cl_mem kernel 1 buf-g1) @@ -606,8 +642,8 @@ EOF (include "test/test-2-acc-runtime.rkt") -(provide run-prim1-ρ! functional->preallocated-1-ρ-acc ext1-ρ-kernel - run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel - run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel - run-prim2-∇! functional->preallocated-2-∇-acc ext2-∇-kernel +(provide run-prim1-ρ! functional->preallocated-1-ρ-acc ext1-ρ-kernel/name + run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel/name + run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel/name + run-prim2-∇! functional->preallocated-2-∇-acc ext2-∇-kernel/name kernel-name) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index e5542bd..dd3ad27 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -16,13 +16,13 @@ ;; to running code on the CPU . (define ext1-ρ - (λ (f f-acc m [shape-fn scalar-shape]) + (λ (f f-acc m [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e1r))]) (λ (t) (cond ((number? t) (f t)) ((expects-preallocated? f-acc) (scalarize - (flat-ext1-ρ f f-acc m shape-fn t))) + (flat-ext1-ρ f f-acc m shape-fn prim-sign t))) (else (let* ((in-shape (flat-shape t)) (base-shape (min-shape m in-shape)) @@ -30,15 +30,15 @@ (flat-f (functional->preallocated-1-ρ f base-shape out-shape)) (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape out-shape))) (scalarize - (flat-ext1-ρ flat-f flat-f-acc m shape-fn t)))))))) + (flat-ext1-ρ flat-f flat-f-acc m shape-fn prim-sign t)))))))) (define ext1-∇ - (λ (f f-acc m [shape-fn scalar-shape]) + (λ (f f-acc m [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e1n))]) (λ (t z) (cond ((number? t) (f t z)) ((expects-preallocated? f-acc) - (scalarize (flat-ext1-∇ f f-acc m shape-fn t (ensure-flat z)))) + (scalarize (flat-ext1-∇ f f-acc m shape-fn prim-sign t (ensure-flat z)))) (else (let* ((in-shape (flat-shape t)) (base-shape (min-shape m in-shape)) @@ -47,7 +47,8 @@ (flat-f-acc (functional->preallocated-1-∇-acc f-acc base-shape out-shape))) (scalarize (flat-ext1-∇ flat-f flat-f-acc m - shape-fn t (ensure-flat z))))))))) + shape-fn prim-sign + t (ensure-flat z))))))))) (define functional->preallocated-1-ρ (λ (f base-shape out-shape) @@ -100,13 +101,13 @@ ;;—————————————————–—————————————————–—————————————————– (define ext2-ρ - (λ (f f-acc m n [shape-fn scalar-shape]) + (λ (f f-acc m n [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e2r))]) (λ (t u) (cond ((and (number? t) (number? u)) (f t u)) ((expects-preallocated? f-acc) (scalarize - (flat-ext2-ρ f f-acc m n shape-fn t u))) + (flat-ext2-ρ f f-acc m n shape-fn prim-sign t u))) ((number? t) (let* ((t-shape '()) (u-shape (min-shape n (flat-shape u))) @@ -114,7 +115,7 @@ (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) (scalarize - (flat-ext2-ρ flat-f flat-f-acc m n shape-fn (ensure-flat t) u)))) + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign (ensure-flat t) u)))) ((number? u) (let* ((t-shape (min-shape m (flat-shape t))) (u-shape '()) @@ -122,7 +123,7 @@ (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) (scalarize - (flat-ext2-ρ flat-f flat-f-acc m n shape-fn t (ensure-flat u))))) + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t (ensure-flat u))))) (else (let* ((t-shape (min-shape m (flat-shape t))) (u-shape (min-shape n (flat-shape u))) @@ -130,14 +131,14 @@ (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) (scalarize - (flat-ext2-ρ flat-f flat-f-acc m n shape-fn t u)))))))) + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t u)))))))) (define ext2-∇ - (λ (f f-acc m n [shape-fn scalar-shape]) + (λ (f f-acc m n [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e2n))]) (λ (t u z) (let ((invoke-flat-ext2-∇ (λ (f f-acc m n shape-fn t u z) - (let-values (((da db) (flat-ext2-∇ f f-acc m n shape-fn t u z))) + (let-values (((da db) (flat-ext2-∇ f f-acc m n shape-fn prim-sign t u z))) (values (scalarize da) (scalarize db)))))) (cond ((and (number? t) (number? u)) (f t u z)) @@ -200,7 +201,7 @@ out-f-shape))) (define flat-ext1-ρ - (λ (f f-acc min-rank shape-fn t0) + (λ (f f-acc min-rank shape-fn f-sign t0) (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) (off0 (flat-offset t0)) @@ -214,10 +215,11 @@ (size-out (size-of s-out)) (v-out (new-vec size-out 0.0))) (cond - ((accelerate?) (run-prim1-ρ! (ext1-ρ-kernel f-acc) - (kernel-name f-acc) - v0 off0 size0 stride0 - v-out size-out stride-out)) + ((accelerate?) + (let-values (((kernel-code kernel-name) (ext1-ρ-kernel/name f-acc f-sign))) + (run-prim1-ρ! kernel-code kernel-name + v0 off0 size0 stride0 + v-out size-out stride-out))) (else (for ([i-out (in-range 0 size-out stride-out)]) (let ((i0 (+ off0 (* (/ i-out stride-out) stride0)))) @@ -225,7 +227,7 @@ (flat s-out v-out 0)))) (define flat-ext1-∇ - (λ (fᵈ fᵈ-acc min-rank shape-fn t0 z) + (λ (fᵈ fᵈ-acc min-rank shape-fn fᵈ-sign t0 z) ;; z has the same shape as the output (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) @@ -243,9 +245,11 @@ (g0 (new-vec size0 0.0))) (cond - ((accelerate?) (run-prim1-∇! (ext1-∇-kernel fᵈ-acc) (kernel-name fᵈ-acc) g0 - v0 off0 size0 stride0 - vz offz size-z stride-z)) + ((accelerate?) + (let-values (((kernel-code kernel-name) (ext1-∇-kernel/name fᵈ-acc fᵈ-sign))) + (run-prim1-∇! kernel-code kernel-name g0 + v0 off0 size0 stride0 + vz offz size-z stride-z))) (else (for ([iz (in-range 0 size-z stride-z)]) (let ((i0 (+ off0 (* (/ iz stride-z) stride0)))) @@ -253,7 +257,7 @@ (flat s0 g0 0)))) (define flat-ext2-ρ - (λ (f f-acc r0 r1 shape-fn t0 t1) + (λ (f f-acc r0 r1 shape-fn f-sign t0 t1) (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) (off0 (flat-offset t0)) @@ -271,14 +275,16 @@ (stride1 (size-of sf1)) (stride-out (size-of sf-out))) (ext2-shapes s0 s1 r0 r1 sf-out + ;;TODO: get rid of "parallel-desc?" (λ (s-out size-out q0 q1 strides parallel-desc?) (let ((out-v (new-vec size-out 0.0))) (cond ((accelerate?) - (run-prim2-ρ! (ext2-ρ-kernel f-acc strides) (kernel-name f-acc) - v0 off0 size0 stride0 - v1 off1 size1 stride1 - out-v size-out stride-out)) + (let-values (((kernel-code kernel-name) (ext2-ρ-kernel/name f-acc f-sign strides))) + (run-prim2-ρ! kernel-code kernel-name + v0 off0 size0 stride0 + v1 off1 size1 stride1 + out-v size-out stride-out))) (else (for ([out-i (in-range 0 size-out stride-out)]) (let-values (((i0 i1) @@ -287,7 +293,7 @@ (flat s-out out-v 0))))))) (define flat-ext2-∇ - (λ (fᵈ fᵈ-acc r0 r1 shape-fn t0 t1 z) + (λ (fᵈ fᵈ-acc r0 r1 shape-fn fᵈ-sign t0 t1 z) (let* ((s0 (flat-shape t0)) (v0 (flat-store t0)) (off0 (flat-offset t0)) @@ -312,9 +318,10 @@ (g1 (new-vec (size-of s1) 0.0))) (cond ((accelerate?) - (let ((kernel-code (ext2-∇-kernel fᵈ-acc strides s0 s1 r0 r1 sz + (let-values (((kernel-code kernel-name) + (ext2-∇-kernel/name fᵈ-acc fᵈ-sign strides s0 s1 r0 r1 sz (length sf-z)))) - (run-prim2-∇! kernel-code (kernel-name fᵈ-acc) + (run-prim2-∇! kernel-code kernel-name g0 g1 v0 off0 size0 stride0 v1 off1 size1 stride1 diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt index f26e605..17130c4 100644 --- a/accelerated-tensors/tensors/test/test-D-extend.rkt +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -4,6 +4,7 @@ (require "A-equality.rkt") (require "B-tensor-basics.rkt") + #| (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) (vset! out-v iₒ @@ -213,7 +214,8 @@ EOF 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 - 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0)))) + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + |#) (module+ test (require rackunit) @@ -226,6 +228,7 @@ EOF (define +ᵈ (λ (a b z) (values z z))) (define +ᵈ-acc +ᵈ) + #| (define sqrᶠ (λ (a) (* a a))) (define sqrᵈ (λ (a z) (* z 2 a))) @@ -234,6 +237,7 @@ EOF "@{z} * 2.0 * @{a}")) (define d-sqr (ext1-∇ sqrᵈ sqrᵈ-acc 0 scalar-shape)) + |# (define one-like (λ (t) @@ -243,17 +247,22 @@ EOF (new-vec size-t 1.0) 0)))) + #| (check-true (equal-elements? (d-sqr r1-td (one-like r1-td)) (tensor 6.0 8.0 10.0))) (let ((gsqr (d-sqr r2-td (one-like r2-td)))) (check-tensor-equal? gsqr (reshape '(2 3) (tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + |# (define d+ (ext2-∇ +ᵈ +ᵈ-acc 0 0 scalar-shape)) + #| (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) (check-tensor-equal? da (tensor 1.0 1.0 1.0)) (check-tensor-equal? db (tensor 1.0 1.0 1.0))) + |# + #; (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) (check-tensor-equal? da (tensor 2.0 2.0 2.0)) (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) @@ -267,6 +276,7 @@ EOF (check-tensor-equal? gt (tensor 1.0 2.0 3.0)) (check-tensor-equal? gu (tensor 2.0 3.0 4.0))) + #| (define sum-1-∇ (λ (g t it st vz iz sz) (for* ([i (in-range it (+ it st))]) @@ -291,4 +301,5 @@ EOF (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0)))) (check-tensor-equal? gt (tensor (tensor 2.0 2.0 2.0) - (tensor 1.0 1.0 1.0))))) + (tensor 1.0 1.0 1.0)))) + |#) From cd41c74e66d821987706fee0fbc1fc7faa7a055a Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 7 Jun 2024 20:52:10 -0400 Subject: [PATCH 44/83] [add-acc]Improve kernel name generation --- accelerated-tensors/autodiff/B-prims.rkt | 75 ++-- accelerated-tensors/tensors/2-acc-runtime.rkt | 334 ++++++++++-------- accelerated-tensors/tensors/D-extend.rkt | 205 ++++++----- .../tensors/test/test-D-extend.rkt | 15 +- 4 files changed, 337 insertions(+), 292 deletions(-) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt index b654ef3..23c06b2 100644 --- a/accelerated-tensors/autodiff/B-prims.rkt +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -43,20 +43,27 @@ ;; results. ;; +;; Primitives need a unique identifier (its signature) which corresponds to a +;; unique GPU kernel name. However, for the graphics driver in a 2017 macbook +;; pro, the kernel name is limited to a length of 15 characters. This limitation +;; therefore limits how many unique primitives can be created. + (define prim1 - (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) - (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) - (∇-callable (ensure-∇-callable-1 ∇-fn shape)) - (prim-sign (symbol->string (gensym 'prim1)))) - (λ (daf) - (cond - ((eq? daf ρ-function) ρ-fn) - ((eq? daf ρ-acc-function) ρ-acc-fn) - ((eq? daf ∇-function) ∇-fn) - ((eq? daf ∇-acc-function) ∇-acc-fn) - ((eq? daf shape-fn) shape) - ((eq? daf signature) prim-sign) - (else (prim1-dual ρ-callable ∇-callable daf))))))) + (let ((id 0)) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-1 ρ-fn shape)) + (∇-callable (ensure-∇-callable-1 ∇-fn shape)) + (prim-sign (string-append "p1" (~r id #:base 16)))) + (set! id (add1 id)) + (λ (daf) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ρ-acc-function) ρ-acc-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf ∇-acc-function) ∇-acc-fn) + ((eq? daf shape-fn) shape) + ((eq? daf signature) prim-sign) + (else (prim1-dual ρ-callable ∇-callable daf)))))))) (define prim1-dual (λ (ρ-fn ∇-fn da) @@ -67,20 +74,22 @@ ((κ da) da ga σ))))))) (define prim2 - (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) - (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) - (∇-callable (ensure-∇-callable-2 ∇-fn shape)) - (prim-sign (symbol->string (gensym 'prim2)))) - (λ ds - (let ((daf (ref ds 0))) - (cond - ((eq? daf ρ-function) ρ-fn) - ((eq? daf ρ-acc-function) ρ-acc-fn) - ((eq? daf ∇-function) ∇-fn) - ((eq? daf ∇-acc-function) ∇-acc-fn) - ((eq? daf shape-fn) shape) - ((eq? daf signature) prim-sign) - (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1))))))))) + (let ((id 0)) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)]) + (let ((ρ-callable (ensure-ρ-callable-2 ρ-fn shape)) + (∇-callable (ensure-∇-callable-2 ∇-fn shape)) + (prim-sign (string-append "p2" (~r id #:base 16)))) + (set! id (add1 id)) + (λ ds + (let ((daf (ref ds 0))) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ρ-acc-function) ρ-acc-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf ∇-acc-function) ∇-acc-fn) + ((eq? daf shape-fn) shape) + ((eq? daf signature) prim-sign) + (else (prim2-dual ρ-callable ∇-callable daf (ref ds 1)))))))))) (define prim2-dual (λ (ρ-fn ∇-fn da db) @@ -227,18 +236,22 @@ (define ext1 (λ (f n) (prim1 - (ext1-ρ (ρ-function f) (ρ-acc-function f) n (shape-fn f) (signature f)) + (ext1-ρ (ρ-function f) (ρ-acc-function f) n (shape-fn f) + (string-append "r" (signature f))) (ρ-acc-function f) - (ext1-∇ (∇-function f) (∇-acc-function f) n (shape-fn f) (signature f)) + (ext1-∇ (∇-function f) (∇-acc-function f) n (shape-fn f) + (string-append "n" (signature f))) (∇-acc-function f) (shape-fn f)))) (define ext2 (λ (f m n) (prim2 - (ext2-ρ (ρ-function f) (ρ-acc-function f) m n (shape-fn f) (signature f)) + (ext2-ρ (ρ-function f) (ρ-acc-function f) m n (shape-fn f) + (string-append "r" (signature f))) (ρ-acc-function f) - (ext2-∇ (∇-function f) (∇-acc-function f) m n (shape-fn f) (signature f)) + (ext2-∇ (∇-function f) (∇-acc-function f) m n (shape-fn f) + (string-append "n" (signature f))) (∇-acc-function f) (shape-fn f)))) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index 93eadcb..7624247 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -8,7 +8,9 @@ "0-vectors.rkt" "../../impl-loader.rkt") -;; TODO: Cache compiled kernels based on a unique prim name +;; TODO: Implement MNIST as an example along with iris and morse + +(define xxh32-ctx (make-xxh32)) (define context (let ([context #f]) @@ -171,6 +173,9 @@ EOF v0 off0 size0 stride0 v-out size-out stride-out) (when (debug-kernel?) + (printf "Number of GPU threads: ~a~n" (/ size-out stride-out)) + (printf "Input size: ~a~n" size0) + (printf "Output size: ~a~n" size-out) (printf "Kernel Code:~n~a~n" kernel-code)) (let* ([buf0 #f] [buf-out #f] @@ -258,6 +263,9 @@ EOF v0 off0 size0 stride0 vz offz size-z stride-z) (when (debug-kernel?) + (printf "Number of GPU threads: ~a~n" (/ size-z stride-z)) + (printf "Input size: ~a~n" size0) + (printf "Output size: ~a~n" size-z) (printf "Kernel Code:~n~a~n" kernel-code)) (let* ([buf0 #f] [buf-z #f] @@ -334,22 +342,24 @@ EOF )))) (define (strides-signature! ctx strides) - (for ((stride-vec strides)) - (match-let* ((`#(,s1 ,s2 ,s3) (vector-map ~a stride-vec))) - (xxh32-update! ctx (string->bytes/utf-8 s1)) - (xxh32-update! ctx #"_") - (xxh32-update! ctx (string->bytes/utf-8 s2)) - (xxh32-update! ctx #"_") - (xxh32-update! ctx (string->bytes/utf-8 s3)) - (xxh32-update! ctx #"#")))) + (xxh32-update! + ctx + (for/fold + ((result #"")) + ((stride-vec strides)) + (match-let* ((`#(,s1 ,s2 ,s3) stride-vec)) + (bytes-append result + (integer->integer-bytes s1 4 #f) + (integer->integer-bytes s2 4 #f) + (integer->integer-bytes s3 4 #f)))))) (define (ext2-ρ-kernel-name prim-sign strides) - (define xxh32-ctx (make-xxh32)) (xxh32-reset! xxh32-ctx 0) (strides-signature! xxh32-ctx strides) (define strides-hash (xxh32-digest xxh32-ctx)) - (format "~a_~a" prim-sign (~a strides-hash))) + (format "~a~a" prim-sign (~r strides-hash #:base 16))) +;;TODO: Memoize this (define (ext2-ρ-kernel/name prim2-ρ-f prim-sign strides) (let*-values (((generate-idxs) (idx-exprs strides 0 0)) ((i0-expr i1-expr) (generate-idxs "i_out")) @@ -379,63 +389,67 @@ EOF v0 off0 size0 stride0 v1 off1 size1 stride1 v-out size-out stride-out) - (when (debug-kernel?) - (printf "Kernel Code:~n~a~n" kernel-code)) - (let* ([buf0 #f] - [buf1 #f] - [buf-out #f] - [program #f] - [kernel #f] - [event #f]) - (dynamic-wind - (λ () - (set! buf0 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size0) - (vref-cpointer v0 off0))) - (set! buf1 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size1) - (vref-cpointer v1 off1))) - (set! buf-out (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY - (* (ctype-sizeof _cl_float) - size-out) - #f)) - (set! program (clCreateProgramWithSource - (context) - (make-vector - 1 - (string->bytes/utf-8 kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0)) - (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) - (clSetKernelArg:_cl_mem kernel 0 buf0) - (clSetKernelArg:_cl_int kernel 1 stride0) - (clSetKernelArg:_cl_mem kernel 2 buf1) - (clSetKernelArg:_cl_int kernel 3 stride1) - (clSetKernelArg:_cl_mem kernel 4 buf-out) - (clSetKernelArg:_cl_int kernel 5 stride-out)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 (/ size-out stride-out)) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size-out) - (vec->cpointer v-out) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-out - (clReleaseMemObject buf-out)) - (when buf1 - (clReleaseMemObject buf1)) - (when buf0 - (clReleaseMemObject buf0)))))) + (when (debug-kernel?) + (printf "Number of GPU threads: ~a~n" (/ size-out stride-out)) + (printf "Input 0 size: ~a~n" size0) + (printf "Input 1 size: ~a~n" size1) + (printf "Output size: ~a~n" size-out) + (printf "Kernel Code:~n~a~n" kernel-code)) + (let* ([buf0 #f] + [buf1 #f] + [buf-out #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf1 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size1) + (vref-cpointer v1 off1))) + (set! buf-out (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size-out) + #f)) + (set! program (clCreateProgramWithSource + (context) + (make-vector + 1 + (string->bytes/utf-8 kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf0) + (clSetKernelArg:_cl_int kernel 1 stride0) + (clSetKernelArg:_cl_mem kernel 2 buf1) + (clSetKernelArg:_cl_int kernel 3 stride1) + (clSetKernelArg:_cl_mem kernel 4 buf-out) + (clSetKernelArg:_cl_int kernel 5 stride-out)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 (/ size-out stride-out)) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 + (* (ctype-sizeof _cl_float) + size-out) + (vec->cpointer v-out) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-out + (clReleaseMemObject buf-out)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))) (define functional->preallocated-2-ρ-acc (λ (f-acc t-shape u-shape out-shape) @@ -456,23 +470,24 @@ EOF (define (ext2-∇-kernel-name prim-sign strides s0 s1 r0 r1 s-out r-out) - (define xxh32-ctx (make-xxh32)) (xxh32-reset! xxh32-ctx 0) (strides-signature! xxh32-ctx strides) - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a s0))) - (xxh32-update! xxh32-ctx #"_") - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a s1))) - (xxh32-update! xxh32-ctx #"_") - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a r0))) - (xxh32-update! xxh32-ctx #"_") - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a r1))) - (xxh32-update! xxh32-ctx #"_") - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a s-out))) - (xxh32-update! xxh32-ctx #"_") - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (~a r-out))) + (xxh32-update! + xxh32-ctx + (bytes-append (apply bytes-append + (map (λ (x) (integer->integer-bytes x 4 #f)) s0)) + (apply bytes-append + (map (λ (x) (integer->integer-bytes x 4 #f)) s1)) + (integer->integer-bytes r0 1 #f) + (integer->integer-bytes r1 1 #f) + (apply bytes-append + (map (λ (x) (integer->integer-bytes x 4 #f)) s-out)) + (integer->integer-bytes r-out 1 #f))) (define params-hash (xxh32-digest xxh32-ctx)) - (format "~a_~a" prim-sign (~a params-hash))) + (format "~a~a" prim-sign (~r params-hash #:base 16))) + +;;TODO: Memoize this (define (ext2-∇-kernel/name prim2-∇-f prim-sign strides s0 s1 r0 r1 s-out r-out) (let*-values (((prim-effect0 prim-effect1) (prim2-∇-f "g" @@ -535,86 +550,91 @@ EOF v0 off0 size0 stride0 v1 off1 size1 stride1 vz offz size-z stride-z) - (when (debug-kernel?) - (printf "Kernel Code:~n~a~n" kernel-code)) - (let* ([global-work-size (max (/ size0 stride0) - (/ size1 stride1))] - [buf0 #f] - [buf1 #f] - [buf-z #f] - [buf-g0 #f] - [buf-g1 #f] - [program #f] - [kernel #f] - [event #f]) - (dynamic-wind - (λ () - (set! buf0 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (when (debug-kernel?) + (printf "Number of GPU threads: ~a~n" (max (/ size0 stride0) + (/ size1 stride1))) + (printf "Input 0 size: ~a~n" size0) + (printf "Input 1 size: ~a~n" size1) + (printf "Output size: ~a~n" size-z) + (printf "Kernel Code:~n~a~n" kernel-code)) + (let* ([global-work-size (max (/ size0 stride0) + (/ size1 stride1))] + [buf0 #f] + [buf1 #f] + [buf-z #f] + [buf-g0 #f] + [buf-g1 #f] + [program #f] + [kernel #f] + [event #f]) + (dynamic-wind + (λ () + (set! buf0 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size0) + (vref-cpointer v0 off0))) + (set! buf1 (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size1) + (vref-cpointer v1 off1))) + (set! buf-z (clCreateBuffer (context) + '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (* (ctype-sizeof _cl_float) + size-z) + (vref-cpointer vz offz))) + (set! buf-g0 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size0) + #f)) + (set! buf-g1 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY + (* (ctype-sizeof _cl_float) + size1) + #f)) + (set! program (clCreateProgramWithSource + (context) + (make-vector 1 (string->bytes/utf-8 kernel-code)))) + (clBuildProgram program (vector (device)) (make-bytes 0)) + (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) + (clSetKernelArg:_cl_mem kernel 0 buf-g0) + (clSetKernelArg:_cl_mem kernel 1 buf-g1) + (clSetKernelArg:_cl_mem kernel 2 buf0) + (clSetKernelArg:_cl_int kernel 3 stride0) + (clSetKernelArg:_cl_int kernel 4 size0) + (clSetKernelArg:_cl_mem kernel 5 buf1) + (clSetKernelArg:_cl_int kernel 6 stride1) + (clSetKernelArg:_cl_int kernel 7 size1) + (clSetKernelArg:_cl_mem kernel 8 buf-z) + (clSetKernelArg:_cl_int kernel 9 stride-z)) + (λ () + (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 + (make-vector 1 global-work-size) + (make-vector 0) + (make-vector 0))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 (* (ctype-sizeof _cl_float) size0) - (vref-cpointer v0 off0))) - (set! buf1 (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) + (vec->cpointer g0) (vector event))) + (set! event (clEnqueueReadBuffer (command-queue) buf-g1 'CL_TRUE 0 (* (ctype-sizeof _cl_float) size1) - (vref-cpointer v1 off1))) - (set! buf-z (clCreateBuffer (context) - '(CL_MEM_USE_HOST_PTR CL_MEM_READ_ONLY) - (* (ctype-sizeof _cl_float) - size-z) - (vref-cpointer vz offz))) - (set! buf-g0 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY - (* (ctype-sizeof _cl_float) - size0) - #f)) - (set! buf-g1 (clCreateBuffer (context) 'CL_MEM_WRITE_ONLY - (* (ctype-sizeof _cl_float) - size1) - #f)) - (set! program (clCreateProgramWithSource - (context) - (make-vector 1 (string->bytes/utf-8 kernel-code)))) - (clBuildProgram program (vector (device)) (make-bytes 0) print-cl-build-log) - (set! kernel (clCreateKernel program (string->bytes/utf-8 ker-name))) - (clSetKernelArg:_cl_mem kernel 0 buf-g0) - (clSetKernelArg:_cl_mem kernel 1 buf-g1) - (clSetKernelArg:_cl_mem kernel 2 buf0) - (clSetKernelArg:_cl_int kernel 3 stride0) - (clSetKernelArg:_cl_int kernel 4 size0) - (clSetKernelArg:_cl_mem kernel 5 buf1) - (clSetKernelArg:_cl_int kernel 6 stride1) - (clSetKernelArg:_cl_int kernel 7 size1) - (clSetKernelArg:_cl_mem kernel 8 buf-z) - (clSetKernelArg:_cl_int kernel 9 stride-z)) - (λ () - (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 - (make-vector 1 global-work-size) - (make-vector 0) - (make-vector 0))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size0) - (vec->cpointer g0) (vector event))) - (set! event (clEnqueueReadBuffer (command-queue) buf-g1 'CL_TRUE 0 - (* (ctype-sizeof _cl_float) - size1) - (vec->cpointer g1) (vector event)))) - (λ () - (when kernel - (clReleaseKernel kernel)) - (when program - (clReleaseProgram program)) - (when buf-g1 - (clReleaseMemObject buf-g1)) - (when buf-g0 - (clReleaseMemObject buf-g0)) - (when buf-z - (clReleaseMemObject buf-z)) - (when buf1 - (clReleaseMemObject buf1)) - (when buf0 - (clReleaseMemObject buf0)))))) + (vec->cpointer g1) (vector event)))) + (λ () + (when kernel + (clReleaseKernel kernel)) + (when program + (clReleaseProgram program)) + (when buf-g1 + (clReleaseMemObject buf-g1)) + (when buf-g0 + (clReleaseMemObject buf-g0)) + (when buf-z + (clReleaseMemObject buf-z)) + (when buf1 + (clReleaseMemObject buf1)) + (when buf0 + (clReleaseMemObject buf0)))))) (define functional->preallocated-2-∇-acc (λ (f-acc t-shape u-shape out-shape) diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index dd3ad27..0bcf663 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -16,39 +16,49 @@ ;; to running code on the CPU . (define ext1-ρ - (λ (f f-acc m [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e1r))]) - (λ (t) - (cond - ((number? t) (f t)) - ((expects-preallocated? f-acc) - (scalarize - (flat-ext1-ρ f f-acc m shape-fn prim-sign t))) - (else - (let* ((in-shape (flat-shape t)) - (base-shape (min-shape m in-shape)) - (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-ρ f base-shape out-shape)) - (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape out-shape))) + (let ((id -1)) + (λ (f f-acc m + [shape-fn scalar-shape] + [prim-sign (begin + (set! id (add1 id)) + (string-append "re1" (~r id #:base 16)))]) + (λ (t) + (cond + ((number? t) (f t)) + ((expects-preallocated? f-acc) (scalarize - (flat-ext1-ρ flat-f flat-f-acc m shape-fn prim-sign t)))))))) + (flat-ext1-ρ f f-acc m shape-fn prim-sign t))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f flat-f-acc m shape-fn prim-sign t))))))))) (define ext1-∇ - (λ (f f-acc m [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e1n))]) - (λ (t z) - (cond - ((number? t) (f t z)) - ((expects-preallocated? f-acc) - (scalarize (flat-ext1-∇ f f-acc m shape-fn prim-sign t (ensure-flat z)))) - (else - (let* ((in-shape (flat-shape t)) - (base-shape (min-shape m in-shape)) - (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) - (flat-f-acc (functional->preallocated-1-∇-acc - f-acc base-shape out-shape))) - (scalarize (flat-ext1-∇ flat-f flat-f-acc m - shape-fn prim-sign - t (ensure-flat z))))))))) + (let ((id -1)) + (λ (f f-acc m + [shape-fn scalar-shape] + [prim-sign (begin + (set! id (add1 id)) + (string-append "ne1" (~r id #:base 16)))]) + (λ (t z) + (cond + ((number? t) (f t z)) + ((expects-preallocated? f-acc) + (scalarize (flat-ext1-∇ f f-acc m shape-fn prim-sign t (ensure-flat z)))) + (else + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-∇-acc + f-acc base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f flat-f-acc m + shape-fn prim-sign + t (ensure-flat z)))))))))) (define functional->preallocated-1-ρ (λ (f base-shape out-shape) @@ -101,70 +111,80 @@ ;;—————————————————–—————————————————–—————————————————– (define ext2-ρ - (λ (f f-acc m n [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e2r))]) - (λ (t u) - (cond - ((and (number? t) (number? u)) (f t u)) - ((expects-preallocated? f-acc) - (scalarize - (flat-ext2-ρ f f-acc m n shape-fn prim-sign t u))) - ((number? t) - (let* ((t-shape '()) - (u-shape (min-shape n (flat-shape u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) - (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (let ((id -1)) + (λ (f f-acc m n + [shape-fn scalar-shape] + [prim-sign (begin + (set! id (add1 id)) + (string-append "re2" (~r id #:base 16)))]) + (λ (t u) + (cond + ((and (number? t) (number? u)) (f t u)) + ((expects-preallocated? f-acc) (scalarize - (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign (ensure-flat t) u)))) - ((number? u) - (let* ((t-shape (min-shape m (flat-shape t))) - (u-shape '()) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) - (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) - (scalarize - (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t (ensure-flat u))))) - (else - (let* ((t-shape (min-shape m (flat-shape t))) - (u-shape (min-shape n (flat-shape u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) - (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) - (scalarize - (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t u)))))))) + (flat-ext2-ρ f f-acc m n shape-fn prim-sign t u))) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign (ensure-flat t) u)))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t (ensure-flat u))))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-ρ-acc f-acc t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f flat-f-acc m n shape-fn prim-sign t u))))))))) (define ext2-∇ - (λ (f f-acc m n [shape-fn scalar-shape] [prim-sign (symbol->string (gensym 'e2n))]) - (λ (t u z) - (let ((invoke-flat-ext2-∇ - (λ (f f-acc m n shape-fn t u z) - (let-values (((da db) (flat-ext2-∇ f f-acc m n shape-fn prim-sign t u z))) - (values (scalarize da) (scalarize db)))))) - (cond - ((and (number? t) (number? u)) (f t u z)) - ((expects-preallocated? f-acc) - (invoke-flat-ext2-∇ f f-acc m n shape-fn t u z)) - ((number? t) - (let* ((t-shape '()) - (u-shape (min-shape n (flat-shape u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) - (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) - (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn (ensure-flat t) u z))) - ((number? u) - (let* ((t-shape (min-shape m (flat-shape t))) - (u-shape '()) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) - (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) - (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t (ensure-flat u) z))) - (else - (let* ((t-shape (min-shape m (flat-shape t))) - (u-shape (min-shape n (flat-shape u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) - (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) - (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t u z)))))))) + (let ((id -1)) + (λ (f f-acc m n + [shape-fn scalar-shape] + [prim-sign (begin + (set! id (add1 id)) + (string-append "ne2" (~r id #:base 16)))]) + (λ (t u z) + (let ((invoke-flat-ext2-∇ + (λ (f f-acc m n shape-fn t u z) + (let-values (((da db) (flat-ext2-∇ f f-acc m n shape-fn prim-sign t u z))) + (values (scalarize da) (scalarize db)))))) + (cond + ((and (number? t) (number? u)) (f t u z)) + ((expects-preallocated? f-acc) + (invoke-flat-ext2-∇ f f-acc m n shape-fn t u z)) + ((number? t) + (let* ((t-shape '()) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn (ensure-flat t) u z))) + ((number? u) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape '()) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t (ensure-flat u) z))) + (else + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc f-acc t-shape u-shape out-shape))) + (invoke-flat-ext2-∇ flat-f flat-f-acc m n shape-fn t u z))))))))) (define functional->preallocated-2-ρ (λ (f t-shape u-shape out-shape) @@ -333,6 +353,7 @@ (values (flat s0 g0 0) (flat s1 g1 0)))))))) +;;TODO: Memoize this (define ext2-shapes (λ (s0 s1 r0 r1 sf-out k) (let ((l0 (length s0)) @@ -343,6 +364,8 @@ (size-of sf-out) (size-of s0) (size-of s1) + ;;TODO: Use a struct instead of a list of triples for strides. The + ;;strides struct should store the hash for the strides. '() #t)) diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt index 17130c4..f26e605 100644 --- a/accelerated-tensors/tensors/test/test-D-extend.rkt +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -4,7 +4,6 @@ (require "A-equality.rkt") (require "B-tensor-basics.rkt") - #| (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) (vset! out-v iₒ @@ -214,8 +213,7 @@ EOF 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 - 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) - |#) + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0)))) (module+ test (require rackunit) @@ -228,7 +226,6 @@ EOF (define +ᵈ (λ (a b z) (values z z))) (define +ᵈ-acc +ᵈ) - #| (define sqrᶠ (λ (a) (* a a))) (define sqrᵈ (λ (a z) (* z 2 a))) @@ -237,7 +234,6 @@ EOF "@{z} * 2.0 * @{a}")) (define d-sqr (ext1-∇ sqrᵈ sqrᵈ-acc 0 scalar-shape)) - |# (define one-like (λ (t) @@ -247,22 +243,17 @@ EOF (new-vec size-t 1.0) 0)))) - #| (check-true (equal-elements? (d-sqr r1-td (one-like r1-td)) (tensor 6.0 8.0 10.0))) (let ((gsqr (d-sqr r2-td (one-like r2-td)))) (check-tensor-equal? gsqr (reshape '(2 3) (tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) - |# (define d+ (ext2-∇ +ᵈ +ᵈ-acc 0 0 scalar-shape)) - #| (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) (check-tensor-equal? da (tensor 1.0 1.0 1.0)) (check-tensor-equal? db (tensor 1.0 1.0 1.0))) - |# - #; (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) (check-tensor-equal? da (tensor 2.0 2.0 2.0)) (check-tensor-equal? db (reshape '(2 3) (tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) @@ -276,7 +267,6 @@ EOF (check-tensor-equal? gt (tensor 1.0 2.0 3.0)) (check-tensor-equal? gu (tensor 2.0 3.0 4.0))) - #| (define sum-1-∇ (λ (g t it st vz iz sz) (for* ([i (in-range it (+ it st))]) @@ -301,5 +291,4 @@ EOF (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0)))) (check-tensor-equal? gt (tensor (tensor 2.0 2.0 2.0) - (tensor 1.0 1.0 1.0)))) - |#) + (tensor 1.0 1.0 1.0))))) From 0d0ea1f3922c4959e942ca4d7b0936f6bbcaf383 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 8 Jun 2024 12:16:56 -0400 Subject: [PATCH 45/83] [add-acc]Memoize ext2-shapes --- accelerated-tensors/tensors/2-acc-runtime.rkt | 29 ++--- accelerated-tensors/tensors/D-extend.rkt | 106 +++++++++--------- accelerated-tensors/tensors/ext2-strides.rkt | 36 ++++++ .../tensors/test/test-D-extend.rkt | 4 +- 4 files changed, 100 insertions(+), 75 deletions(-) create mode 100644 accelerated-tensors/tensors/ext2-strides.rkt diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index 7624247..69c7931 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -6,7 +6,9 @@ string-interpolation file/xxhash32 "0-vectors.rkt" - "../../impl-loader.rkt") + "../../impl-loader.rkt" + "ext2-strides.rkt") + ;; TODO: Implement MNIST as an example along with iris and morse @@ -85,7 +87,7 @@ (for/fold ([i0 (number->string i0)] [i1 (number->string i1)] [x out-i] #:result (values i0 i1)) - ([stride strides]) + ([stride (strides-strides strides)]) (let ((stride-out (number->string (vector-ref stride 0))) (stride0 (number->string (vector-ref stride 1))) (stride1 (number->string (vector-ref stride 2)))) @@ -105,7 +107,7 @@ [predivisor-rep repeats] [x i-in-var-str] #:result i-out) ([desc-out s-out] ;; s-out == (append descents-out sf-out) - [stride strides]) ;; (len strides) == (len descents-out) + [stride (strides-strides strides)]) ;; (len strides) == (len descents-out) (let ((stride-out (vector-ref stride 0)) (stride-in (vector-ref stride stride-i))) (cond @@ -341,23 +343,8 @@ EOF EOF )))) -(define (strides-signature! ctx strides) - (xxh32-update! - ctx - (for/fold - ((result #"")) - ((stride-vec strides)) - (match-let* ((`#(,s1 ,s2 ,s3) stride-vec)) - (bytes-append result - (integer->integer-bytes s1 4 #f) - (integer->integer-bytes s2 4 #f) - (integer->integer-bytes s3 4 #f)))))) - (define (ext2-ρ-kernel-name prim-sign strides) - (xxh32-reset! xxh32-ctx 0) - (strides-signature! xxh32-ctx strides) - (define strides-hash (xxh32-digest xxh32-ctx)) - (format "~a~a" prim-sign (~r strides-hash #:base 16))) + (format "~a~a" prim-sign (strides-signature strides))) ;;TODO: Memoize this (define (ext2-ρ-kernel/name prim2-ρ-f prim-sign strides) @@ -471,7 +458,6 @@ EOF (define (ext2-∇-kernel-name prim-sign strides s0 s1 r0 r1 s-out r-out) (xxh32-reset! xxh32-ctx 0) - (strides-signature! xxh32-ctx strides) (xxh32-update! xxh32-ctx (bytes-append (apply bytes-append @@ -483,8 +469,9 @@ EOF (apply bytes-append (map (λ (x) (integer->integer-bytes x 4 #f)) s-out)) (integer->integer-bytes r-out 1 #f))) + (xxh32-update! xxh32-ctx (string->bytes/utf-8 (strides-signature strides))) (define params-hash (xxh32-digest xxh32-ctx)) - (format "~a~a" prim-sign (~r params-hash #:base 16))) + (format "~a~a" prim-sign params-hash)) ;;TODO: Memoize this diff --git a/accelerated-tensors/tensors/D-extend.rkt b/accelerated-tensors/tensors/D-extend.rkt index 0bcf663..6043280 100644 --- a/accelerated-tensors/tensors/D-extend.rkt +++ b/accelerated-tensors/tensors/D-extend.rkt @@ -6,6 +6,7 @@ (require "2-acc-runtime.ss") (require "B-tensor-basics.ss") (require "C-tensor-ops.ss") +(require "ext2-strides.rkt") ;;—————————————————–—————————————————–—————————————————– ;; Unary Pointwise extension @@ -208,7 +209,7 @@ (for/fold ([i0 i0] [i1 i1] [x out-i] #:result (values i0 i1)) - ([stride strides]) + ([stride (strides-strides strides)]) (let ((idx (quotient x (vector-ref stride 0))) (next-x (remainder x (vector-ref stride 0)))) (values (+ i0 (* idx (vector-ref stride 1))) @@ -295,8 +296,7 @@ (stride1 (size-of sf1)) (stride-out (size-of sf-out))) (ext2-shapes s0 s1 r0 r1 sf-out - ;;TODO: get rid of "parallel-desc?" - (λ (s-out size-out q0 q1 strides parallel-desc?) + (λ (s-out size-out q0 q1 strides) (let ((out-v (new-vec size-out 0.0))) (cond ((accelerate?) @@ -333,7 +333,7 @@ (vz (flat-store z)) (offz (flat-offset z))) (ext2-shapes s0 s1 r0 r1 sf-z - (λ (sz size-z q0 q1 strides parallel-desc?) + (λ (sz size-z q0 q1 strides) (let ((g0 (new-vec (size-of s0) 0.0)) (g1 (new-vec (size-of s1) 0.0))) (cond @@ -353,77 +353,79 @@ (values (flat s0 g0 0) (flat s1 g1 0)))))))) -;;TODO: Memoize this +;;TODO: Create a caching macro to generalize caching of functions (define ext2-shapes - (λ (s0 s1 r0 r1 sf-out k) - (let ((l0 (length s0)) - (l1 (length s1))) + (let ((cache (make-hash))) + (λ (s0 s1 r0 r1 sf-out k) + (define cache-key (equal-hash-code (list s0 s1 r0 r1 sf-out))) (cond - ((and (= r0 l0) (= r1 l1)) - (k sf-out - (size-of sf-out) - (size-of s0) - (size-of s1) - ;;TODO: Use a struct instead of a list of triples for strides. The - ;;strides struct should store the hash for the strides. - '() - #t)) - - ((= r0 l0) - (ext2-shapes s0 (cdr s1) r0 r1 sf-out - (desc-right (car s1) k))) - - ((= r1 l1) - (ext2-shapes (cdr s0) s1 r0 r1 sf-out - (desc-left (car s0) k))) - - ((and (not (null? s0)) - (not (null? s1)) - (= (car s0) (car s1))) - (ext2-shapes (cdr s0) (cdr s1) r0 r1 sf-out - (desc-both (car s0) k))) - - ((> l1 l0) - (ext2-shapes s0 (cdr s1) r0 r1 sf-out - (desc-right (car s1) k))) - - ((> l0 l1) - (ext2-shapes (cdr s0) s1 r0 r1 sf-out - (desc-left (car s0) k))) - - (else (error 'ext - "Shapes are incompatible for ext2: ~a, and ~a for min ranks ~a, and ~a~%" - s0 s1 r0 r1)))))) + [(hash-has-key? cache cache-key) (apply k (hash-ref cache cache-key))] + [else + (let ((l0 (length s0)) + (l1 (length s1)) + (k (λ args + (hash-set! cache cache-key args) + (apply k args)))) + (cond + ((and (= r0 l0) (= r1 l1)) + (k sf-out + (size-of sf-out) + (size-of s0) + (size-of s1) + strides-null)) + + ((= r0 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((= r1 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + ((and (not (null? s0)) + (not (null? s1)) + (= (car s0) (car s1))) + (ext2-shapes (cdr s0) (cdr s1) r0 r1 sf-out + (desc-both (car s0) k))) + + ((> l1 l0) + (ext2-shapes s0 (cdr s1) r0 r1 sf-out + (desc-right (car s1) k))) + + ((> l0 l1) + (ext2-shapes (cdr s0) s1 r0 r1 sf-out + (desc-left (car s0) k))) + + (else (error 'ext + "Shapes are incompatible for ext2: ~a, and ~a for min ranks ~a, and ~a~%" + s0 s1 r0 r1))))])))) (define desc-both (λ (d k) - (λ (s-out qout q0 q1 strides parallel-desc?) + (λ (s-out qout q0 q1 strides) (k (cons d s-out) (* qout d) (* q0 d) (* q1 d) - (cons (vector qout q0 q1) strides) - parallel-desc?)))) + (strides-cons qout q0 q1 strides))))) (define desc-left (λ (d k) - (λ (s-out qout q0 q1 strides parallel-desc?) + (λ (s-out qout q0 q1 strides) (k (cons d s-out) (* qout d) (* q0 d) q1 - (cons (vector qout q0 0) strides) - #f)))) + (strides-cons qout q0 0 strides))))) (define desc-right (λ (d k) - (λ (s-out qout q0 q1 strides parallel-desc?) + (λ (s-out qout q0 q1 strides) (k (cons d s-out) (* qout d) q0 (* q1 d) - (cons (vector qout 0 q1) strides) - #f)))) + (strides-cons qout 0 q1 strides))))) (define v-copy-flat! (λ (vg ig a) diff --git a/accelerated-tensors/tensors/ext2-strides.rkt b/accelerated-tensors/tensors/ext2-strides.rkt new file mode 100644 index 0000000..abdd1c2 --- /dev/null +++ b/accelerated-tensors/tensors/ext2-strides.rkt @@ -0,0 +1,36 @@ +#lang racket + +(require file/xxhash32) + +(provide strides-null + strides-cons + (rename-out (ext2-strides-strides strides-strides) + (ext2-strides-sign strides-signature))) + +(define (strides-signature! ctx strides) + (xxh32-update! + ctx + (for/fold + ((result #"")) + ((stride-vec strides)) + (match-let* ((`#(,s1 ,s2 ,s3) stride-vec)) + (bytes-append result + (integer->integer-bytes s1 4 #f) + (integer->integer-bytes s2 4 #f) + (integer->integer-bytes s3 4 #f)))))) + +(struct ext2-strides ((strides #:mutable) sign)) + +(define strides-null + (ext2-strides '() (let ((ctx (make-xxh32))) + (~r (xxh32-digest ctx) #:base 16)))) + +(define strides-cons + (λ (st-out st0 st1 strides) + (let ((new-list (cons (vector st-out st0 st1) + (ext2-strides-strides strides)))) + (ext2-strides new-list + (begin + (let ((ctx (make-xxh32))) + (strides-signature! ctx new-list) + (~r (xxh32-digest ctx) #:base 16))))))) diff --git a/accelerated-tensors/tensors/test/test-D-extend.rkt b/accelerated-tensors/tensors/test/test-D-extend.rkt index f26e605..78ba2a1 100644 --- a/accelerated-tensors/tensors/test/test-D-extend.rkt +++ b/accelerated-tensors/tensors/test/test-D-extend.rkt @@ -94,10 +94,10 @@ EOF (define r1 1) (ext2-shapes s0 s1 r0 r1 '(5 6) - (λ (s-out size-out q0 q1 strides parallel-desc?) + (λ (s-out size-out q0 q1 strides) (check-equal? s-out '(3 4 7 5 6)) (check-equal? size-out 2520) - (check-equal? strides '(#(840 120 42) #(210 30 0) #(30 0 6))) + (check-equal? (strides-strides strides) '(#(840 120 42) #(210 30 0) #(30 0 6))) (let-values (((i0 i1) (idxs strides 0 0 0))) (check-equal? i0 0) (check-equal? i1 0)) From 4486a357b2f97f3aba5d28860bc5c97e5d4e6174 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 12 Jun 2024 01:46:45 -0400 Subject: [PATCH 46/83] [add-acc]Memoize ext2-*-kernel/name --- accelerated-tensors/tensors/2-acc-runtime.rkt | 73 ++++++++++++------- 1 file changed, 45 insertions(+), 28 deletions(-) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index 69c7931..da89054 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -346,13 +346,18 @@ EOF (define (ext2-ρ-kernel-name prim-sign strides) (format "~a~a" prim-sign (strides-signature strides))) -;;TODO: Memoize this -(define (ext2-ρ-kernel/name prim2-ρ-f prim-sign strides) - (let*-values (((generate-idxs) (idx-exprs strides 0 0)) - ((i0-expr i1-expr) (generate-idxs "i_out")) - ((kernel-name) (ext2-ρ-kernel-name prim-sign strides))) - (values - #<bytes/utf-8 (strides-signature strides))) (xxh32-update! xxh32-ctx (bytes-append (apply bytes-append @@ -469,28 +478,32 @@ EOF (apply bytes-append (map (λ (x) (integer->integer-bytes x 4 #f)) s-out)) (integer->integer-bytes r-out 1 #f))) - (xxh32-update! xxh32-ctx (string->bytes/utf-8 (strides-signature strides))) (define params-hash (xxh32-digest xxh32-ctx)) (format "~a~a" prim-sign params-hash)) -;;TODO: Memoize this -(define (ext2-∇-kernel/name prim2-∇-f prim-sign strides - s0 s1 r0 r1 s-out r-out) - (let*-values (((prim-effect0 prim-effect1) (prim2-∇-f "g" - "v0" "i0" "stride0" - "v1" "i1" "stride1" - "vz" "iz" "stride_z")) - ((repeats0 repeats1) (calc-repeats s0 s1 r0 r1 s-out r-out)) - ((generate-idxs) (idx-exprs strides 0 0)) - ((generate-idxs-inv) (idx-exprs-inv strides 0 - repeats0 repeats1 s-out)) - ((i0-expr i1-expr) (generate-idxs "iz")) - ((iz-expr0 iz-expr1) (generate-idxs-inv "i0" "i1" "i_rep")) - ((kernel-name) (ext2-∇-kernel-name prim-sign strides - s0 s1 r0 r1 s-out r-out))) - (values - #< Date: Tue, 16 Jul 2024 20:30:56 -0400 Subject: [PATCH 47/83] [add-acc]Add zeroes as a primitive --- accelerated-tensors.rkt | 2 +- accelerated-tensors/ext-ops.rkt | 2 +- accelerated-tensors/ext-ops/A-scalar-ops.rkt | 5 ++++- accelerated-tensors/no-duals-no-overrides.rkt | 2 +- accelerated-tensors/no-duals.rkt | 2 +- accelerated-tensors/no-overrides.rkt | 2 +- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/accelerated-tensors.rkt b/accelerated-tensors.rkt index 5fa0292..9132390 100644 --- a/accelerated-tensors.rkt +++ b/accelerated-tensors.rkt @@ -33,7 +33,7 @@ (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/accelerated-tensors/ext-ops.rkt b/accelerated-tensors/ext-ops.rkt index 83af7de..fc223f5 100644 --- a/accelerated-tensors/ext-ops.rkt +++ b/accelerated-tensors/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/accelerated-tensors/ext-ops/A-scalar-ops.rkt b/accelerated-tensors/ext-ops/A-scalar-ops.rkt index 7893190..72b33f9 100644 --- a/accelerated-tensors/ext-ops/A-scalar-ops.rkt +++ b/accelerated-tensors/ext-ops/A-scalar-ops.rkt @@ -193,6 +193,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) (λ (_) "0.0") 0)) + (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 @@ -204,4 +207,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/accelerated-tensors/no-duals-no-overrides.rkt b/accelerated-tensors/no-duals-no-overrides.rkt index ac07a7a..07ca22e 100644 --- a/accelerated-tensors/no-duals-no-overrides.rkt +++ b/accelerated-tensors/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ concat-ρ concat-n-ρ flatten-ρ diff --git a/accelerated-tensors/no-duals.rkt b/accelerated-tensors/no-duals.rkt index 927c8c7..cd1bcaf 100644 --- a/accelerated-tensors/no-duals.rkt +++ b/accelerated-tensors/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/accelerated-tensors/no-overrides.rkt b/accelerated-tensors/no-overrides.rkt index 35dcbdd..05844b7 100644 --- a/accelerated-tensors/no-overrides.rkt +++ b/accelerated-tensors/no-overrides.rkt @@ -34,7 +34,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ From c748d73552ea2b68be9ef791a17083191b7197f9 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 25 Jun 2024 23:04:22 -0400 Subject: [PATCH 48/83] [add-acc]Add a work size parameter for opencl --- accelerated-tensors/autodiff/B-prims.rkt | 2 +- accelerated-tensors/tensors/2-acc-runtime.rkt | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt index 23c06b2..3166b35 100644 --- a/accelerated-tensors/autodiff/B-prims.rkt +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -135,7 +135,7 @@ (cond ((expects-preallocated? ∇-fn) (λ (ra rb z) - (apply-flat-∇-fn-1 ∇-fn ra rb z shape-fn))) + (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn))) (else ∇-fn)))) (define apply-flat-ρ-fn-1 diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index da89054..df8627e 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -12,6 +12,7 @@ ;; TODO: Implement MNIST as an example along with iris and morse +(define local-work-size (make-parameter #f)) (define xxh32-ctx (make-xxh32)) (define context @@ -207,9 +208,10 @@ EOF (clSetKernelArg:_cl_mem kernel 2 buf-out) (clSetKernelArg:_cl_int kernel 3 stride-out)) (λ () + ;;TODO: Try using the local-work-size argument (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 (/ size-out stride-out)) - (make-vector 0) + (or (local-work-size) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -306,7 +308,7 @@ EOF (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 (/ size-z stride-z)) - (make-vector 0) + (or (local-work-size) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -428,7 +430,7 @@ EOF (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 (/ size-out stride-out)) - (make-vector 0) + (or (local-work-size) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -614,7 +616,7 @@ EOF (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 global-work-size) - (make-vector 0) + (or (local-work-size) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 (* (ctype-sizeof _cl_float) From 61da49256d28cb765ed84e830a856e09c2a56889 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Thu, 27 Jun 2024 18:37:38 -0400 Subject: [PATCH 49/83] [add-acc]Fix setting local work size during kernel exec --- accelerated-tensors/tensors/2-acc-runtime.rkt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/accelerated-tensors/tensors/2-acc-runtime.rkt b/accelerated-tensors/tensors/2-acc-runtime.rkt index df8627e..f815a0b 100644 --- a/accelerated-tensors/tensors/2-acc-runtime.rkt +++ b/accelerated-tensors/tensors/2-acc-runtime.rkt @@ -211,7 +211,7 @@ EOF ;;TODO: Try using the local-work-size argument (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 (/ size-out stride-out)) - (or (local-work-size) (make-vector 0)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -308,7 +308,7 @@ EOF (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 (/ size-z stride-z)) - (or (local-work-size) (make-vector 0)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-g 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -430,7 +430,7 @@ EOF (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 (/ size-out stride-out)) - (or (local-work-size) (make-vector 0)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-out 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -616,7 +616,7 @@ EOF (λ () (set! event (clEnqueueNDRangeKernel (command-queue) kernel 1 (make-vector 1 global-work-size) - (or (local-work-size) (make-vector 0)) + (if (local-work-size) (make-vector 1 (local-work-size)) (make-vector 0)) (make-vector 0))) (set! event (clEnqueueReadBuffer (command-queue) buf-g0 'CL_TRUE 0 (* (ctype-sizeof _cl_float) @@ -672,4 +672,4 @@ EOF run-prim1-∇! functional->preallocated-1-∇-acc ext1-∇-kernel/name run-prim2-ρ! functional->preallocated-2-ρ-acc ext2-ρ-kernel/name run-prim2-∇! functional->preallocated-2-∇-acc ext2-∇-kernel/name - kernel-name) + kernel-name local-work-size) From 8b1827ccee4e6049d3f8a657f86c5633b22d9d73 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:39:46 -0400 Subject: [PATCH 50/83] [add-acc]Switch compiled tensor runtime to acc tensor impl --- accelerated-tensors/autodiff/B-prims.rkt | 6 +++++- accelerated-tensors/ext-impl.rkt | 14 ++++++++++++++ accelerated-tensors/tensors/B-tensor-basics.rkt | 2 +- 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt index 3166b35..8259fbf 100644 --- a/accelerated-tensors/autodiff/B-prims.rkt +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -255,4 +255,8 @@ (∇-acc-function f) (shape-fn f)))) -(provide prim1 prim2 ext1 ext2) +(provide prim1 prim2 ext1 ext2 + apply-flat-ρ-fn-1 + apply-flat-ρ-fn-2 + apply-flat-∇-fn-1 + apply-flat-∇-fn-2) diff --git a/accelerated-tensors/ext-impl.rkt b/accelerated-tensors/ext-impl.rkt index 6eb7d34..b93c110 100644 --- a/accelerated-tensors/ext-impl.rkt +++ b/accelerated-tensors/ext-impl.rkt @@ -1,6 +1,9 @@ #lang racket (require "tensors/0-vectors.rkt") (require "tensors/1-flats.rkt") +(require (only-in "tensors/2-acc-runtime.rkt" + ext2-∇-kernel/name + run-prim2-∇!)) (require (only-in "tensors/B-tensor-basics.rkt" merge-flats)) (require (only-in "tensors/D-extend.rkt" @@ -14,15 +17,26 @@ functional->preallocated-1-∇ functional->preallocated-2-ρ functional->preallocated-2-∇ + functional->preallocated-1-ρ-acc + functional->preallocated-1-∇-acc + functional->preallocated-2-ρ-acc + functional->preallocated-2-∇-acc idxs scalarize ensure-flat)) +(require (only-in "autodiff/B-prims.rkt" + apply-flat-ρ-fn-1 + apply-flat-ρ-fn-2 + apply-flat-∇-fn-1 + apply-flat-∇-fn-2)) (require (only-in "autodiff/E-print.rkt" make-printable-flat fake-tensor)) (provide (all-from-out "tensors/0-vectors.rkt")) (provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/2-acc-runtime.rkt")) (provide (all-from-out "tensors/B-tensor-basics.rkt")) (provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/B-prims.rkt")) (provide (all-from-out "autodiff/E-print.rkt")) diff --git a/accelerated-tensors/tensors/B-tensor-basics.rkt b/accelerated-tensors/tensors/B-tensor-basics.rkt index 63e9a71..39ac818 100644 --- a/accelerated-tensors/tensors/B-tensor-basics.rkt +++ b/accelerated-tensors/tensors/B-tensor-basics.rkt @@ -71,7 +71,7 @@ (cond ((null? lst) (error 'list->flat-tensor "No elements found")) ((number? (car lst)) - (flat (list (length lst)) (list->vec lst) 0)) + (flat (list (length lst)) (list->vec (map exact->inexact lst)) 0)) (else (flat-tensor-from-list lst))))) From 7d16ce3562515a8c1d20ce8f1a0b07c9658b2201 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:44:53 -0400 Subject: [PATCH 51/83] [add-acc]Add primitive ast nodes to compiled impl --- accelerated-tensors/autodiff/B-prims.rkt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/accelerated-tensors/autodiff/B-prims.rkt b/accelerated-tensors/autodiff/B-prims.rkt index 8259fbf..044b2cb 100644 --- a/accelerated-tensors/autodiff/B-prims.rkt +++ b/accelerated-tensors/autodiff/B-prims.rkt @@ -178,7 +178,7 @@ (λ (ρ-fn ra rb shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) @@ -202,7 +202,7 @@ (λ (∇-fn ra rb z shape-fn) (let* ((in-shape-a (flat-shape ra)) (in-size-a (size-of in-shape-a)) - (in-shape-b (flat-shape ra)) + (in-shape-b (flat-shape rb)) (in-size-b (size-of in-shape-b)) (out-shape (shape-fn in-shape-a in-shape-b)) (out-size (size-of out-shape))) From 1fa22d12fcfa8608a9af70f3bd3fe0e638b21f30 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 20:12:26 -0400 Subject: [PATCH 52/83] [add-lazy]Add lazy tensor implementation --- Makefile | 43 ++ flat-tensors/autodiff/E-print.rkt | 5 +- flat-tensors/ext-impl.rkt | 22 + flat-tensors/tensors/B-tensor-basics.rkt | 2 +- flat-tensors/tensors/D-extend.rkt | 4 +- impl-no-duals-no-overrides.rkt | 2 + impl-no-duals.rkt | 2 + impl-no-overrides.rkt | 2 + impl.rkt | 2 + lazy.rkt | 41 ++ lazy/autodiff.rkt | 22 + lazy/autodiff/A-autodiff.rkt | 120 ++++++ lazy/autodiff/B-prims.rkt | 72 ++++ lazy/autodiff/C-dualized-tensor-ops.rkt | 47 +++ lazy/autodiff/D-test-helpers.rkt | 53 +++ lazy/autodiff/E-print.rkt | 29 ++ lazy/autodiff/test/test-A-autodiff.rkt | 15 + lazy/autodiff/test/test-E-print.rkt | 71 ++++ lazy/ext-ops.rkt | 33 ++ lazy/ext-ops/A-scalar-ops.rkt | 133 ++++++ lazy/ext-ops/B-comparators.rkt | 85 ++++ lazy/ext-ops/C-star-2-1.rkt | 43 ++ lazy/ext-ops/D-sum.rkt | 67 +++ lazy/ext-ops/E-argmax.rkt | 40 ++ lazy/ext-ops/F-max.rkt | 48 +++ lazy/ext-ops/G-correlate.rkt | 94 +++++ lazy/ext-ops/test/test-A-scalar-ops.rkt | 122 ++++++ lazy/ext-ops/test/test-B-comparators.rkt | 12 + lazy/ext-ops/test/test-C-star-2-1.rkt | 24 ++ lazy/ext-ops/test/test-D-sum.rkt | 53 +++ lazy/ext-ops/test/test-E-argmax.rkt | 17 + lazy/ext-ops/test/test-F-max.rkt | 10 + lazy/ext-ops/test/test-G-correlate.rkt | 118 ++++++ lazy/no-duals-no-overrides.rkt | 28 ++ lazy/no-duals.rkt | 28 ++ lazy/no-overrides.rkt | 37 ++ lazy/tensors.rkt | 19 + lazy/tensors/0-lazy.rkt | 494 +++++++++++++++++++++++ lazy/tensors/A-equality.rkt | 18 + lazy/tensors/test/test-0-lazy.rkt | 227 +++++++++++ lazy/tensors/test/test-A-equality.rkt | 49 +++ malted/test/test-O-init.rkt | 10 +- 42 files changed, 2356 insertions(+), 7 deletions(-) create mode 100644 flat-tensors/ext-impl.rkt create mode 100644 lazy.rkt create mode 100644 lazy/autodiff.rkt create mode 100644 lazy/autodiff/A-autodiff.rkt create mode 100644 lazy/autodiff/B-prims.rkt create mode 100644 lazy/autodiff/C-dualized-tensor-ops.rkt create mode 100644 lazy/autodiff/D-test-helpers.rkt create mode 100644 lazy/autodiff/E-print.rkt create mode 100644 lazy/autodiff/test/test-A-autodiff.rkt create mode 100644 lazy/autodiff/test/test-E-print.rkt create mode 100644 lazy/ext-ops.rkt create mode 100644 lazy/ext-ops/A-scalar-ops.rkt create mode 100644 lazy/ext-ops/B-comparators.rkt create mode 100644 lazy/ext-ops/C-star-2-1.rkt create mode 100644 lazy/ext-ops/D-sum.rkt create mode 100644 lazy/ext-ops/E-argmax.rkt create mode 100644 lazy/ext-ops/F-max.rkt create mode 100644 lazy/ext-ops/G-correlate.rkt create mode 100644 lazy/ext-ops/test/test-A-scalar-ops.rkt create mode 100644 lazy/ext-ops/test/test-B-comparators.rkt create mode 100644 lazy/ext-ops/test/test-C-star-2-1.rkt create mode 100644 lazy/ext-ops/test/test-D-sum.rkt create mode 100644 lazy/ext-ops/test/test-E-argmax.rkt create mode 100644 lazy/ext-ops/test/test-F-max.rkt create mode 100644 lazy/ext-ops/test/test-G-correlate.rkt create mode 100644 lazy/no-duals-no-overrides.rkt create mode 100644 lazy/no-duals.rkt create mode 100644 lazy/no-overrides.rkt create mode 100644 lazy/tensors.rkt create mode 100644 lazy/tensors/0-lazy.rkt create mode 100644 lazy/tensors/A-equality.rkt create mode 100644 lazy/tensors/test/test-0-lazy.rkt create mode 100644 lazy/tensors/test/test-A-equality.rkt diff --git a/Makefile b/Makefile index b5b0bfb..a6fc7d0 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,7 @@ TEST_FLAGS=-q # add sources be sure to update this section. #---------------------------------------- +LAZY_DIR=lazy LEARNER_DIR=learner FLAT_DIR=flat-tensors UNIFORM_DIR=uniform-tensors @@ -18,6 +19,46 @@ NESTED_DIR=nested-tensors TOOLS_DIR=tools MALTED_DIR=malted +# lazy +LAZY_TENSORS_DIR=$(LAZY_DIR)/tensors +LAZY_AUTODIFF_DIR=$(LAZY_DIR)/autodiff +LAZY_EXT_OPS_DIR=$(LAZY_DIR)/ext-ops + +LAZY_TENSORS_SOURCES=\ + $(LAZY_TENSORS_DIR)/0-lazy.rkt\ + $(LAZY_TENSORS_DIR)/A-equality.rkt\ + $(LAZY_DIR)/tensors.rkt + +LAZY_AUTODIFF_SOURCES=\ + $(LAZY_AUTODIFF_DIR)/A-autodiff.rkt\ + $(LAZY_AUTODIFF_DIR)/B-prims.rkt\ + $(LAZY_AUTODIFF_DIR)/C-dualized-tensor-ops.rkt\ + $(LAZY_AUTODIFF_DIR)/D-test-helpers.rkt\ + $(LAZY_DIR)/autodiff.rkt + +LAZY_EXT_OPS_SOURCES=\ + $(LAZY_EXT_OPS_DIR)/A-scalar-ops.rkt\ + $(LAZY_EXT_OPS_DIR)/B-comparators.rkt\ + $(LAZY_EXT_OPS_DIR)/C-star-2-1.rkt\ + $(LAZY_EXT_OPS_DIR)/D-sum.rkt\ + $(LAZY_EXT_OPS_DIR)/E-argmax.rkt\ + $(LAZY_EXT_OPS_DIR)/F-max.rkt\ + $(LAZY_EXT_OPS_DIR)/G-correlate.rkt\ + $(LAZY_DIR)/ext-ops.rkt + +LAZY_LOADERS=\ + $(LAZY_DIR)/no-duals-no-overrides.rkt\ + $(LAZY_DIR)/no-duals.rkt\ + $(LAZY_DIR)/no-overrides.rkt\ + $(LAZY_DIR)/tensors.rkt\ + $(LAZY_DIR)/autodiff.rkt\ + $(LAZY_DIR)/ext-ops.rkt + +LAZY_SOURCES=$(LAZY_TENSORS_SOURCES)\ + $(LAZY_AUTODIFF_SOURCES)\ + $(LAZY_EXT_OPS_SOURCES)\ + $(LAZY_LOADERS) + # learner LEARNER_TENSORS_DIR=$(LEARNER_DIR)/tensors LEARNER_AUTODIFF_DIR=$(LEARNER_DIR)/autodiff @@ -104,6 +145,7 @@ FLAT_LOADERS=\ $(FLAT_DIR)/no-overrides.rkt\ $(FLAT_DIR)/tensors.rkt\ $(FLAT_DIR)/autodiff.rkt\ + $(FLAT_DIR)/ext-impl.rkt\ $(FLAT_DIR)/ext-ops.rkt FLAT_SOURCES=$(FLAT_TENSORS_SOURCES)\ @@ -281,6 +323,7 @@ MALTED_SOURCES=\ # All the sources together, plus entry points SOURCES=$(LEARNER_SOURCES)\ + $(LAZY_SOURCES)\ $(FLAT_SOURCES)\ $(UNIFORM_SOURCES)\ $(ACCELERATED_SOURCES)\ diff --git a/flat-tensors/autodiff/E-print.rkt b/flat-tensors/autodiff/E-print.rkt index cd0b5d7..ca3ce62 100644 --- a/flat-tensors/autodiff/E-print.rkt +++ b/flat-tensors/autodiff/E-print.rkt @@ -80,4 +80,7 @@ (include "test/test-E-print.rkt") (provide max-tensor-print-length - make-printable) + make-printable + ;; This is used in ext-impl.rkt + make-printable-flat + fake-tensor) diff --git a/flat-tensors/ext-impl.rkt b/flat-tensors/ext-impl.rkt new file mode 100644 index 0000000..b79efe3 --- /dev/null +++ b/flat-tensors/ext-impl.rkt @@ -0,0 +1,22 @@ +#lang racket +(require "tensors/0-vectors.rkt") +(require "tensors/1-flats.rkt") +(require (only-in "tensors/B-tensor-basics.rkt" + merge-flats)) +(require (only-in "tensors/D-extend.rkt" + merge-shapes + min-shape + ext2-shapes + flat-ext1-∇ + flat-ext1-ρ + flat-ext2-ρ + idxs)) +(require (only-in "autodiff/E-print.rkt" + make-printable-flat + fake-tensor)) + +(provide (all-from-out "tensors/0-vectors.rkt")) +(provide (all-from-out "tensors/1-flats.rkt")) +(provide (all-from-out "tensors/B-tensor-basics.rkt")) +(provide (all-from-out "tensors/D-extend.rkt")) +(provide (all-from-out "autodiff/E-print.rkt")) diff --git a/flat-tensors/tensors/B-tensor-basics.rkt b/flat-tensors/tensors/B-tensor-basics.rkt index 4af5d69..60df715 100644 --- a/flat-tensors/tensors/B-tensor-basics.rkt +++ b/flat-tensors/tensors/B-tensor-basics.rkt @@ -179,4 +179,4 @@ (include "test/test-B-tensor-basics.rkt") (provide tref tlen list->tensor number? - tensor? tensor build-tensor trefs) + tensor? tensor build-tensor trefs merge-flats) diff --git a/flat-tensors/tensors/D-extend.rkt b/flat-tensors/tensors/D-extend.rkt index d2bde38..8d6c068 100644 --- a/flat-tensors/tensors/D-extend.rkt +++ b/flat-tensors/tensors/D-extend.rkt @@ -385,4 +385,6 @@ (include "test/test-D-extend.rkt") -(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated?) +(provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? + merge-shapes min-shape ext2-shapes idxs + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ) diff --git a/impl-no-duals-no-overrides.rkt b/impl-no-duals-no-overrides.rkt index 8073099..c01c9a6 100644 --- a/impl-no-duals-no-overrides.rkt +++ b/impl-no-duals-no-overrides.rkt @@ -11,12 +11,14 @@ (printf "Tensor implementation (no-duals, no-overrides): ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy/no-duals-no-overrides.rkt")) ((learner) #'(require "learner/no-duals-no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals-no-overrides.rkt")) ((uniform-tensors) #'(require "uniform-tensors/no-duals-no-overrides.rkt")) ((accelerated-tensors) #'(require "accelerated-tensors/no-duals-no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals-no-overrides.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy/no-duals-no-overrides.rkt"))) ((learner) #'(provide (all-from-out "learner/no-duals-no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals-no-overrides.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals-no-overrides.rkt"))) diff --git a/impl-no-duals.rkt b/impl-no-duals.rkt index ba72d04..a90faf4 100644 --- a/impl-no-duals.rkt +++ b/impl-no-duals.rkt @@ -11,12 +11,14 @@ (printf "Tensor implementation (no-duals): ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy/no-duals.rkt")) ((learner) #'(require "learner/no-duals.rkt")) ((flat-tensors) #'(require "flat-tensors/no-duals.rkt")) ((uniform-tensors) #'(require "uniform-tensors/no-duals.rkt")) ((accelerated-tensors) #'(require "accelerated-tensors/no-duals.rkt")) ((nested-tensors) #'(require "nested-tensors/no-duals.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy/no-duals.rkt"))) ((learner) #'(provide (all-from-out "learner/no-duals.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-duals.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-duals.rkt"))) diff --git a/impl-no-overrides.rkt b/impl-no-overrides.rkt index b1e6669..4ece6b3 100644 --- a/impl-no-overrides.rkt +++ b/impl-no-overrides.rkt @@ -11,12 +11,14 @@ (printf "Tensor implementation (no-overrides): ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy/no-overrides.rkt")) ((learner) #'(require "learner/no-overrides.rkt")) ((flat-tensors) #'(require "flat-tensors/no-overrides.rkt")) ((uniform-tensors) #'(require "uniform-tensors/no-overrides.rkt")) ((accelerated-tensors) #'(require "accelerated-tensors/no-overrides.rkt")) ((nested-tensors) #'(require "nested-tensors/no-overrides.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy/no-overrides.rkt"))) ((learner) #'(provide (all-from-out "learner/no-overrides.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors/no-overrides.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors/no-overrides.rkt"))) diff --git a/impl.rkt b/impl.rkt index 3e47f85..bf1c1ca 100644 --- a/impl.rkt +++ b/impl.rkt @@ -11,12 +11,14 @@ (printf "Tensor implementation: ~s~%" (tensor-implementation)) #`(begin #,(case (tensor-implementation) + ((lazy) #'(require "lazy.rkt")) ((learner) #'(require "learner.rkt")) ((flat-tensors) #'(require "flat-tensors.rkt")) ((uniform-tensors) #'(require "uniform-tensors.rkt")) ((accelerated-tensors) #'(require "accelerated-tensors.rkt")) ((nested-tensors) #'(require "nested-tensors.rkt"))) #,(case (tensor-implementation) + ((lazy) #'(provide (all-from-out "lazy.rkt"))) ((learner) #'(provide (all-from-out "learner.rkt"))) ((flat-tensors) #'(provide (all-from-out "flat-tensors.rkt"))) ((uniform-tensors) #'(provide (all-from-out "uniform-tensors.rkt"))) diff --git a/lazy.rkt b/lazy.rkt new file mode 100644 index 0000000..5fdf862 --- /dev/null +++ b/lazy.rkt @@ -0,0 +1,41 @@ +#lang racket/base + +(require + (except-in "lazy/tensors.rkt" + rank shape reshape trefs tensor? tlen ref refr)) + +(require "lazy/autodiff.rkt") +(require "lazy/ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print make-printable max-tensor-print-length check-dual-equal? check-ρ-∇ + + (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) + (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) + (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate)) + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 + + sum-1 argmax-1 max-1 + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/autodiff.rkt b/lazy/autodiff.rkt new file mode 100644 index 0000000..66f8f3a --- /dev/null +++ b/lazy/autodiff.rkt @@ -0,0 +1,22 @@ +#lang racket + +(require "autodiff/A-autodiff.rkt") +(require "autodiff/B-prims.rkt") +(require "autodiff/C-dualized-tensor-ops.rkt") +(require "autodiff/D-test-helpers.rkt") +(require "autodiff/E-print.rkt") + +(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual*) +(provide prim1 prim2 ext1 ext2) +(provide (rename-out (d-rank rank) + (d-shape shape) + (d-reshape reshape) + (d-trefs trefs) + (d-tensor? tensor?) + (d-tlen tlen) + (d-ref ref) + (d-refr refr))) + +(provide check-dual-equal? check-ρ-∇) + +(provide make-printable max-tensor-print-length) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt new file mode 100644 index 0000000..51a54ef --- /dev/null +++ b/lazy/autodiff/A-autodiff.rkt @@ -0,0 +1,120 @@ +#lang racket + +(require "../tensors.rkt") + +;;---------------------------- +;; Real part of a dual is always a tensor (of any rank) +;;---------------------------- + +(define dual? + (λ (x) + (and (vector? x) (eq? (vector-ref x 0) dual)))) + +(define dual + (λ (r k) + (vector dual r k))) + +(define dual* + (λ (d) + (dual (ρ d) end-of-chain))) + +(define ρ + (λ (d) + (cond + ((dual? d) (vector-ref d 1)) + (else d)))) + +(define κ + (λ (d) + (cond + ((dual? d) (vector-ref d 2)) + (else end-of-chain)))) + +(define scalar? + (λ (d) + (or (number? d) + (and (dual? d) + (number? (ρ d)))))) + +(define dual-like? + (λ (d) + (or (dual? d) + (number? d) + (tensor? d)))) + +;;---------------------------- +;; Chain rule +;;---------------------------- + +(define end-of-chain + (λ (d z σ) + (let ((g (hash-ref σ d 0.0))) + (hash-set σ d (+-ρ z g))))) + +(define +-ρ + (ext2-ρ + 0 0)) + +;;---------------------------- +;; Reverse-mode AD +;;---------------------------- + +(define ∇ + (λ (f theta) + (let ((wrt (map* dual* theta))) + (∇-once (f wrt) wrt)))) + +(define ∇¹ + (λ (f) + (λ xs + (let ((wrt (map* dual* xs))) + (∇-once (apply f wrt) wrt))))) + +(define ∇-once + (λ (y wrt) + (let ((σ (∇σ y (hasheq)))) + (map* (λ (d) + (hash-ref σ d 0.0)) + wrt)))) + +(define ∇σ + (λ (y σ) + (cond + ((dual-like? y) ((κ y) y (one-like (ρ y)) σ)) + ((list? y) (∇σ-list y σ)) + (else (printf "Unknown: ~a~%" y))))) + +(define ∇σ-list + (λ (y σ) + (cond + ((null? y) σ) + (else + (let ((σ-hat (∇σ (ref y 0) σ))) + (∇σ-list (refr y 1) σ-hat)))))) + +;;---------------------------- +;; General helpers +;;---------------------------- + +(define map* + (λ (f y) + (cond + ((dual-like? y) (f y)) + ((list? y) + (map (λ (yi) + (map* f yi)) + y)) + (else y)))) + +(define trace-print + (λ (v port) + (cond + ((dual? v) (trace-print (ρ v) port)) + (else (fprintf port "~a~%" v))))) + +(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) + +(include "test/test-A-autodiff.rkt") + +(provide + dual dual? ρ κ ∇ ∇¹ dual* scalar? end-of-chain + trace-print) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt new file mode 100644 index 0000000..879f28f --- /dev/null +++ b/lazy/autodiff/B-prims.rkt @@ -0,0 +1,72 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + +(define ρ-function + (λ (f) (f ρ-function))) + +(define ∇-function + (λ (f) (f ∇-function))) + +;;TODO: add more metadata to functions so that we know which function is being +;; passed to the extend functions. + +(define shape-fn + (λ (f) (f shape-fn))) + +(define prim1 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (λ (daf) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim1-dual ρ-fn ∇-fn daf)))))) + +(define prim1-dual + (λ (ρ-fn ∇-fn da) + (let ((ra (ρ da))) + (dual (ρ-fn ra) + (λ (d z σ) + (let ((ga (∇-fn ra z))) + ((κ da) da ga σ))))))) + +(define prim2 + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) + (λ ds + (let ((daf (ref ds 0))) + (cond + ((eq? daf ρ-function) ρ-fn) + ((eq? daf ∇-function) ∇-fn) + ((eq? daf shape-fn) shape) + (else (prim2-dual ρ-fn ∇-fn daf (ref ds 1)))))))) + +(define prim2-dual + (λ (ρ-fn ∇-fn da db) + (let ((ra (ρ da)) + (rb (ρ db))) + (dual (ρ-fn ra rb) + (λ (d z σ) + (let-values (((ga gb) (∇-fn ra rb z))) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat)))))))) + +;;---------------------------- +;; Dualized tensor op creators +;;---------------------------- +(define ext1 + (λ (f n) + (prim1 + (ext1-ρ (ρ-function f) n (shape-fn f)) + (ext1-∇ (∇-function f) n (shape-fn f)) + (shape-fn f)))) + +(define ext2 + (λ (f m n) + (prim2 + (ext2-ρ (ρ-function f) m n (shape-fn f)) + (ext2-∇ (∇-function f) m n (shape-fn f)) + (shape-fn f)))) + +(provide prim1 prim2 ext1 ext2) diff --git a/lazy/autodiff/C-dualized-tensor-ops.rkt b/lazy/autodiff/C-dualized-tensor-ops.rkt new file mode 100644 index 0000000..982d20b --- /dev/null +++ b/lazy/autodiff/C-dualized-tensor-ops.rkt @@ -0,0 +1,47 @@ +#lang racket + +(require "../tensors.rkt") +(require "A-autodiff.ss") + + +;;---------------------------- +;; Tensor ops, cleaned up. +;;---------------------------- + +(define d-rank + (lambda (t) + (rank (ρ t)))) + +(define d-shape + (λ (t) + (shape (ρ t)))) + +(define d-reshape + (λ (s t) + (cond + ((dual? t) + (dual (reshape s (ρ t)) + (κ t))) + (else (reshape s t))))) + +(define d-trefs + (λ (t b) + (trefs (ρ t) b))) + +(define d-tensor? + (λ (t) + (tensor? (ρ t)))) + +(define d-tlen + (λ (t) + (tlen (ρ t)))) + +(define d-ref + (λ (l i) + (ref l (ρ i)))) + +(define d-refr + (λ (l i) + (refr l (ρ i)))) + +(provide d-rank d-shape d-reshape d-trefs d-tensor? d-tlen d-ref d-refr) diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt new file mode 100644 index 0000000..c017981 --- /dev/null +++ b/lazy/autodiff/D-test-helpers.rkt @@ -0,0 +1,53 @@ +#lang racket + +(require "../tensors.rkt") +(require (only-in "../tensors/0-lazy.rkt" tp-force)) +(require "A-autodiff.ss") + +(require rackunit) + +(define forced-ρ + (λ (d) + (tp-force (ρ d)))) + +(define-binary-check (check-dual-equal? equal-wt? actual expected)) +(define-check (ρ-∇-checker fn args ans grads) + (let* ((y (tp-force (apply fn args))) + (g (tp-force (apply (∇¹ fn) args)))) + (cond + ((and (equal-wt? ans (ρ y)) + (equal-wt? grads (ρ g))) (void)) + ((equal-wt? ans (ρ y)) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" + (ρ g) grads))) + (else + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" + (ρ y) ans)))))) + +(define-syntax check-ρ-∇ + (syntax-rules () + [(check-both (fn args ...) ans grads) + (ρ-∇-checker fn (list args ...) ans grads)])) + +(define equal-wt? + (λ (a b) + (cond + ((and (tensor? a) (tensor? b)) + (tensor-equal? a b)) + ((dual? a) (equal-wt? (ρ a) b)) + ((dual? b) (equal-wt? a (ρ b))) + ((and (vector? a) (vector? b) + (= (vector-length a) (vector-length b))) + (vector-andmap equal-wt? a b)) + ((and (pair? a) (pair? b) + (= (length a) (length b))) + (andmap equal-wt? a b)) + (else (equal? a b))))) + + +(define vector-andmap + (λ (f v1 v2) + (for/fold ([s #t]) ([v1 v1][v2 v2]) + (and s (f v1 v2))))) + +(provide check-dual-equal? check-ρ-∇) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt new file mode 100644 index 0000000..6c85c70 --- /dev/null +++ b/lazy/autodiff/E-print.rkt @@ -0,0 +1,29 @@ +#lang racket + +(require "A-autodiff.rkt") +(require "../tensors/0-lazy.rkt") +(require "../../flat-tensors/ext-impl.rkt") + +(define max-tensor-print-length (make-parameter 5)) + +(define make-printable + (λ (y [max-length (max-tensor-print-length)]) + (cond + ((dual? y) (make-printable (ρ y))) + ((and (not (scalar? y)) (tensor? y)) + (make-printable-tp y max-length)) + ((list? y) + (map (λ (le) (make-printable le max-length)) y)) + ((vector? y) + (vector-map (λ (ve) (make-printable ve max-length)) y)) + (else y)))) + +(define make-printable-tp + (λ (y [max-length (max-tensor-print-length)]) + (make-printable-flat (tp-force y) max-length))) + + +(include "test/test-E-print.rkt") + +(provide max-tensor-print-length + make-printable) diff --git a/lazy/autodiff/test/test-A-autodiff.rkt b/lazy/autodiff/test/test-A-autodiff.rkt new file mode 100644 index 0000000..b453b3e --- /dev/null +++ b/lazy/autodiff/test/test-A-autodiff.rkt @@ -0,0 +1,15 @@ +(module+ test + (require rackunit) + (let ((k0 end-of-chain)) + (let ((dual0 0) + (dual1 (dual 1 k0))) + + (check-equal? dual1 (dual 1 k0)) + (check-true (dual? dual1)) + (check-false (dual? 1)) + (check-equal? (ρ dual1) 1) + (check-equal? (ρ dual0) 0) + (check-equal? (κ dual1) k0) + + (check-equal? (map* (λ (d) (ρ d)) (∇-once dual1 (list dual0 dual1))) + '(0.0 1.0))))) diff --git a/lazy/autodiff/test/test-E-print.rkt b/lazy/autodiff/test/test-E-print.rkt new file mode 100644 index 0000000..91fde6c --- /dev/null +++ b/lazy/autodiff/test/test-E-print.rkt @@ -0,0 +1,71 @@ +(module+ test + (require rackunit) + (require "../tensors.rkt") + + (define long-tensor + (tensor 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15)) + + (define dualized-long-tensor + (dual long-tensor end-of-chain)) + + (define deep-tensor + (tensor long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor + long-tensor long-tensor long-tensor long-tensor long-tensor)) + + (define deeper-tensor + (tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor + deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) + + (check-equal? (make-printable long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable deep-tensor 3) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...))) + + (check-equal? (make-printable deeper-tensor 3) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + '...))) + (parameterize ((max-tensor-print-length 3)) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) + (list + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor + (list + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + (fake-tensor + (list (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1 2 3 ...)) + '...)) + '...)))))) diff --git a/lazy/ext-ops.rkt b/lazy/ext-ops.rkt new file mode 100644 index 0000000..d3dff13 --- /dev/null +++ b/lazy/ext-ops.rkt @@ -0,0 +1,33 @@ +#lang racket + +(require "ext-ops/A-scalar-ops.rkt") +(require "ext-ops/B-comparators.rkt") +(require "ext-ops/C-star-2-1.rkt") +(require "ext-ops/D-sum.rkt") +(require "ext-ops/E-argmax.rkt") +(require "ext-ops/F-max.rkt") +(require "ext-ops/G-correlate.rkt") + +(provide d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ) + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) + +(provide d*-2-1 *-2-1-ρ) + +(provide sum-1 d-sum sum-ρ d-sum-cols sum-cols-ρ) + +(provide argmax-1 d-argmax argmax-ρ) + +(provide max-1 d-max max-ρ) + +(provide correlate-ρ d-correlate) diff --git a/lazy/ext-ops/A-scalar-ops.rkt b/lazy/ext-ops/A-scalar-ops.rkt new file mode 100644 index 0000000..9c46948 --- /dev/null +++ b/lazy/ext-ops/A-scalar-ops.rkt @@ -0,0 +1,133 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) +(require "../autodiff.rkt") + +(define +-0-0 + (prim2 + + (λ (a b z) + (values z z)))) + +(define --0-0 + (prim2 - + (λ (a b z) + (values z (- z))))) + +(define *-0-0 + (prim2 * + (λ (a b z) + (values (* b z) (* a z))))) + +(define /-0-0 + (prim2 / + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))))) + +(define expt-0-0 + (prim2 expt + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))))) + +(define exp-0 + (prim1 exp + (λ (a z) + (* z (exp a))))) + +(define log-0 + (prim1 log + (λ (a z) + (* z (/ 1 a))))) + +(define abs-0-ρ + (λ (x) + (cond + ((< x 0) (* -1 x)) + (else x)))) + +(define abs-0-∇ + (λ (x z) + (cond + ((< x 0) (- z)) + (else z)))) + +(define abs-0 + (prim1 abs-0-ρ abs-0-∇)) + +(define rectify-0-ρ + (λ (s) + (cond + ((< s 0.0) 0.0) + (else s)))) + +(define rectify-0-∇ + (λ (s z) + (cond + ((< s 0.0) 0.0) + (else z)))) + +(define rectify-shape + (λ (s) s)) + +(define rectify-0 + (prim1 rectify-0-ρ rectify-0-∇ rectify-shape)) + +;;------------------------------------ +;; differentiable extended functions. +;;------------------------------------ + +(define d* (ext2 *-0-0 0 0)) +(define d+ (ext2 +-0-0 0 0)) +(define d- (ext2 --0-0 0 0)) +(define d/ (ext2 /-0-0 0 0)) +(define d-expt (ext2 expt-0-0 0 0)) + +(define d-exp (ext1 exp-0 0)) +(define d-log (ext1 log-0 0)) +(define d-abs (ext1 abs-0 0)) +(define d-rectify (ext1 rectify-0 0)) + +(define d-sqrt + (λ (a) + (d-expt a 1/2))) + +(define d-sqr + (λ (x) + (d* x x))) + +;;------------------------------------ +;; non-differentiable extended functions. +;;------------------------------------ + +(define *-ρ (ext2-ρ * 0 0)) +(define +-ρ (ext2-ρ + 0 0)) +(define --ρ (ext2-ρ - 0 0)) +(define /-ρ (ext2-ρ / 0 0)) +(define expt-ρ (ext2-ρ expt 0 0)) + +(define exp-ρ (ext1-ρ exp 0)) +(define log-ρ (ext1-ρ log 0)) +(define abs-ρ (ext1-ρ abs-0-ρ 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) + +(define sqrt-ρ + (λ (a) + (expt-ρ a 1/2))) + +(define sqr-ρ + (λ (x) + (*-ρ x x))) + +(include "test/test-A-scalar-ops.rkt") + +(provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 + + d+ d- d* d/ + d-expt d-exp d-log d-abs + d-rectify d-sqrt d-sqr + + +-ρ --ρ *-ρ /-ρ + expt-ρ exp-ρ log-ρ abs-ρ + rectify-ρ sqrt-ρ sqr-ρ) diff --git a/lazy/ext-ops/B-comparators.rkt b/lazy/ext-ops/B-comparators.rkt new file mode 100644 index 0000000..c42a2cf --- /dev/null +++ b/lazy/ext-ops/B-comparators.rkt @@ -0,0 +1,85 @@ +#lang racket + +(require "../autodiff.rkt") + +;;---------------------------- +;; Boolean comparators +;;---------------------------- + +(define comparator + (λ (f) + (λ (da db) + (f (ρ da) (ρ db))))) + +(define =-0-0 + (comparator =)) + +(define <-0-0 + (comparator <)) + +(define <=-0-0 + (comparator <=)) + +(define >-0-0 + (comparator >)) + +(define >=-0-0 + (comparator >)) + +;;---------------------------- +;; Tensorized comparators +;;---------------------------- + +(define comparator-ρ + (λ (f) + (λ (da db) + (cond + ((f (ρ da) (ρ db)) 1.0) + (else 0.0))))) + +(define comparator-∇ + (λ (f) + (λ (da db z) + (cond + ((f (ρ da) (ρ db)) (values z z)) + (else (values 0.0 0.0)))))) + +(define comparator-shape + (λ (f) + (λ (sa sb) + sa))) + +(define comparator-prim + (λ (f) + (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + +(define extended-comparator + (λ (f) + (ext2 (comparator-prim f) 0 0))) + +(define =-1 + (extended-comparator =)) + +(define <-1 + (extended-comparator <)) + +(define >-1 + (extended-comparator >)) + +(define <=-1 + (extended-comparator <=)) + +(define >=-1 + (extended-comparator >=)) + +(define != + (λ (a b) + (not (= a b)))) + +(define !=-1 + (extended-comparator !=)) + +(include "test/test-B-comparators.rkt") + +(provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/ext-ops/C-star-2-1.rkt b/lazy/ext-ops/C-star-2-1.rkt new file mode 100644 index 0000000..75ecaee --- /dev/null +++ b/lazy/ext-ops/C-star-2-1.rkt @@ -0,0 +1,43 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext2-ρ)) +(require "../autodiff.rkt") + +(define *-2-1-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (vector-set! v-out (+ i-out i) + (* (vector-ref v0 (+ i0 i)) + (vector-ref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (let ((a (vector-ref v0 (+ i0 i))) + (b (vector-ref v1 (+ i1 (modulo i stride1)))) + (z (vector-ref vz (+ iz i)))) + (vector-set! g0 (+ i0 i) + (+ (vector-ref g0 (+ i0 i)) (* z b))) + (vector-set! g1 (+ i1 (modulo i stride1)) + (+ (vector-ref g1 (+ i1 (modulo i stride1))) (* z a))))))) + +(define *-2-1-shape + (λ (s t) + s)) + +(define *-2-1 + (prim2 *-2-1-base-ρ *-2-1-base-∇ *-2-1-shape)) + +(define d*-2-1 + (ext2 *-2-1 2 1)) + +(define *-2-1-ρ + (ext2-ρ *-2-1-base-ρ 2 1 *-2-1-shape)) + +(include "test/test-C-star-2-1.rkt") + +(provide *-2-1-ρ d*-2-1) diff --git a/lazy/ext-ops/D-sum.rkt b/lazy/ext-ops/D-sum.rkt new file mode 100644 index 0000000..046131f --- /dev/null +++ b/lazy/ext-ops/D-sum.rkt @@ -0,0 +1,67 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define sum-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vector-set! v-out i-out + (for/fold ([sum 0.0]) ([i (in-range i0 (+ i0 stride0))]) + (+ sum (vector-ref v0 i)))))) + +(define sum-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vector-ref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vector-set! g0 i + (+ (vector-ref g0 i) z)))))) + +(define sum-shape + (λ (st) + (refr st 1))) + +(define sum-1 + (prim1 sum-1-ρ sum-1-∇ sum-shape)) + +(define d-sum + (ext1 sum-1 1)) + +(define sum-ρ + (ext1-ρ sum-1-ρ 1 sum-shape)) + +(provide d-sum sum-ρ) + +(define sum-cols-2-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (for ((i (in-range 0 stride-out))) + (vector-set! v-out (+ i i-out) + (for/fold ([sum 0.0]) ([j (in-range i0 (+ i0 stride0) stride-out)]) + (+ sum (vector-ref v0 (+ j i)))))))) + +(define sum-cols-2-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (for ((i (in-range 0 stride-z))) + (for ([j (in-range i0 (+ i0 stride0) stride-z)]) + (vector-set! g0 (+ i j) + (+ (vector-ref g0 (+ i j)) (vector-ref vz (+ i iz)))))))) + +(define sum-cols-shape + (λ (s) + (refr s 1))) + +(define sum-cols-2 + (prim1 sum-cols-2-ρ sum-cols-2-∇ sum-cols-shape)) + +(define d-sum-cols + (ext1 sum-cols-2 2)) + +(define sum-cols-ρ + (ext1-ρ sum-cols-2-ρ 2 sum-cols-shape)) + +(include "test/test-D-sum.rkt") + +(provide sum-1 d-sum-cols sum-cols-ρ) diff --git a/lazy/ext-ops/E-argmax.rkt b/lazy/ext-ops/E-argmax.rkt new file mode 100644 index 0000000..ad4aefe --- /dev/null +++ b/lazy/ext-ops/E-argmax.rkt @@ -0,0 +1,40 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define argmax-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vector-set! v-out i-out + (for/fold ([max 0.0] + [max-i -1] #:result max-i) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vector-ref v0 i))) + (cond + ((> v max) (values v (+ (- i i0) 0.0))) + (else (values max max-i)))))))) + +(define argmax-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vector-ref vz iz))) + (for ([i (in-range i0 (+ i0 stride0))]) + (vector-set! g0 i 0.0))))) + +(define argmax-shape + (λ (st) + '())) + +(define argmax-1 + (prim1 argmax-1-ρ argmax-1-∇ argmax-shape)) + +(define d-argmax + (ext1 argmax-1 1)) + +(define argmax-ρ + (ext1-ρ argmax-1-ρ 1 argmax-shape)) + +(include "test/test-E-argmax.rkt") + +(provide argmax-1 d-argmax argmax-ρ) diff --git a/lazy/ext-ops/F-max.rkt b/lazy/ext-ops/F-max.rkt new file mode 100644 index 0000000..5a54bd5 --- /dev/null +++ b/lazy/ext-ops/F-max.rkt @@ -0,0 +1,48 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ)) +(require "../autodiff.rkt") + +(define max-1-ρ + (λ (v0 i0 stride0 + v-out i-out stride-out) + (vector-set! v-out i-out + (for/fold ([max 0.0]) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vector-ref v0 i))) + (cond + ((> v max) v) + (else max))))))) + +(define max-1-∇ + (λ (g0 v0 i0 stride0 + vz iz stride-z) + (let ((z (vector-ref vz iz))) + (for/fold ([max -inf.0] + [max-i -1] #:result + (for ([i (in-range i0 (+ i0 stride0))]) + (cond + ((= i (+ i0 max-i)) (vector-set! g0 i z)) + (else (vector-set! g0 i 0.0))))) + ([i (in-range i0 (+ i0 stride0))]) + (let ((v (vector-ref v0 i))) + (cond + ((> v max) (values v (- i i0))) + (else (values max max-i)))))))) + +(define max-shape + (λ (st) + (cdr st))) + +(define max-1 + (prim1 max-1-ρ max-1-∇ max-shape)) + +(define d-max + (ext1 max-1 1)) + +(define max-ρ + (ext1-ρ max-1-ρ 1 max-shape)) + +(include "test/test-F-max.rkt") + +(provide max-1 d-max max-ρ) diff --git a/lazy/ext-ops/G-correlate.rkt b/lazy/ext-ops/G-correlate.rkt new file mode 100644 index 0000000..4cbe0e5 --- /dev/null +++ b/lazy/ext-ops/G-correlate.rkt @@ -0,0 +1,94 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext2-ρ len)) +(require "../autodiff.rkt") + +;; Correlation is written taking into account how ext2 works +;; Ext2 is responsible for producing the i-out'th output from +;; v0[i0] and v1[i1], we take advantage of this. The shape constants +;; n b m d are pre-calculated the striding constants nd md and qd +;; are calculated. + +(define correlate-3-1-ρ + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (vector-set! v-out (+ i-out i) + (for/fold ([sum 0.0]) ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (cond + ((and (>= bi i1-min) (< bi i1-max)) + (let ((a (vector-ref v0 ai)) + (b (vector-ref v1 bi))) + (+ sum (* a b)))) + (else sum)))))))))) + +(define correlate-3-1-∇ + (λ (nd md qd) + (λ (g0 g1 + v0 i0 bmd + v1 i1 d + vz iz b) + (let* ((i1-min (- i1 (modulo i1 nd))) + (i1-max (+ i1-min nd))) + (for ((i (in-range 0 b))) + (let ((z (vector-ref vz (+ iz i)))) + (for ([j (in-range 0 md)]) + (let ((ai (+ i0 (* i md) j)) + (bi (- (+ i1 j) qd))) + (when (and (>= bi i1-min) (< bi i1-max)) + (let ((a (vector-ref v0 ai)) + (b (vector-ref v1 bi))) + (vector-set! g0 ai + (+ (vector-ref g0 ai) (* z b))) + (vector-set! g1 bi + (+ (vector-ref g1 bi) (* z a))))))))))))) + +(define correlate-shape + (λ (bmd nd) + (list (car bmd)))) + +(define correlate-3-1 + (λ (nd md qd) + (prim2 + (correlate-3-1-ρ nd md qd) + (correlate-3-1-∇ nd md qd) + correlate-shape))) + +(define d-correlate + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2 (correlate-3-1 nd md qd) 3 1) bank signal)))) + +(define correlate-ρ + (λ (bank signal) + (let* ((b-m-d (last 3 (shape (ρ bank)))) + (n-d (last 2 (shape (ρ signal)))) + (d (ref n-d 1)) + (nd (* d (ref n-d 0))) + (m (ref b-m-d 1)) + (q (/ (- m 1) 2)) ;; This is the padding. + (qd (* q d)) + (md (* m d))) + ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape) + bank signal)))) + +(define last + (λ (n s) + (refr s (- (len s) n)))) + +(include "test/test-G-correlate.rkt") + +(provide d-correlate correlate-ρ) diff --git a/lazy/ext-ops/test/test-A-scalar-ops.rkt b/lazy/ext-ops/test/test-A-scalar-ops.rkt new file mode 100644 index 0000000..2c13e39 --- /dev/null +++ b/lazy/ext-ops/test/test-A-scalar-ops.rkt @@ -0,0 +1,122 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + ;; Check basic numericals + (let ((a 2) + (b 3)) + (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) + (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0)) + (check-ρ-∇ (d* a b) 6 (list 3.0 2.0)) + (check-ρ-∇ (d/ a b) + 2/3 + '(0.3333333333333333 -0.2222222222222222)) + (check-ρ-∇ (d-exp a) (exp 2) (list (exp 2))) + (check-ρ-∇ (d-log a) (log 2) (list 0.5)) + (check-ρ-∇ (d-expt a b) 8 + (list 12.0 5.545177444479562)) + (check-ρ-∇ (d-sqrt a) + (sqrt 2) + (list 0.3535533905932738)) + + (let* ((first-derivative ((∇¹ d-sqr) b))) + (check-dual-equal? first-derivative (list 6.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* 2)) + (d (dual* 2))) + (check-dual-equal? + ((∇¹ z) c d) + (list 2.5 2.0))) + + (check-dual-equal? ((∇¹ z) a a) (list 2.5 2.0))) + + ;; Check numericals with vector-duals + + (let ((a (tensor 2.0 3.0 4.0)) + (b (tensor 3.0 8.0 9.0))) + (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) + (check-ρ-∇ (d- a b) (tensor -1 -5 -5) (list (tensor 1.0 1.0 1.0) (tensor -1.0 -1.0 -1.0))) + (check-ρ-∇ (d* a b) (tensor 6.0 24.0 36.0) (list (tensor 3.0 8.0 9.0) (tensor 2.0 3.0 4.0))) + (check-ρ-∇ (d/ a b) + (tensor (/ 2.0 3.0) (/ 3.0 8.0) (/ 4.0 9.0)) + (list (tensor 0.3333333333333333 0.125 0.1111111111111111) + (tensor -0.2222222222222222 -0.046875 -0.04938271604938271))) + (check-ρ-∇ (d-exp a) (tensor (exp 2.0) (exp 3.0) (exp 4.0)) + (list (tensor (exp 2.0) (exp 3.0) (exp 4.0)))) + (check-ρ-∇ (d-log a) (tensor (log 2.0) (log 3.0) (log 4.0)) + (list (tensor 0.5 (/ 1 3.0) (/ 1 4.0)))) + (check-ρ-∇ (d-expt a b) + (tensor (expt 2.0 3.0) (expt 3.0 8.0) (expt 4.0 9.0)) + (list (tensor 12.0 17496.0 589824.0) + (tensor 5.545177444479562 7207.9952259514685 363408.7490014126))) + + (check-ρ-∇ (d-sqrt a) + (tensor (sqrt 2.0) (sqrt 3.0) (sqrt 4.0)) + (list (tensor 0.3535533905932738 0.28867513459481287 0.25))) + + (check-ρ-∇ (d-sqr b) (tensor 9.0 64.0 81.0) (list (tensor 6.0 16.0 18.0))) + + (define z + (lambda (x y) + (d+ (d-log x) (d* x y)))) + + (let ((c (dual* (tensor 2.0 3.0 4.0))) + (d (dual* (tensor 2.0 3.0 4.0)))) + (check-dual-equal? + ((∇¹ z) c d) + (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + (check-dual-equal? ((∇¹ z) a a) (list (tensor 2.5 3.3333333333333335 4.25) (tensor 2.0 3.0 4.0)))) + + ;; Check numericals with lists + (let ((x (dual* 3)) + (y (dual* 2))) + (let ((f d*)) + (check-ρ-∇ (d* x y) 6 '(2.0 3.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list (f m m) (f n n)))) + x y) + '(6.0 4.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (list + (list (f m m) (f n n)) + (list (f m m) (f n n))))) + x y) + '(12.0 8.0)) + (check-dual-equal? + ((∇¹ (λ (m n) + (map (λ (m) (f m m)) + (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list m n)))) + x y) + '(6.0 4.0)) + + (check-dual-equal? + ((∇¹ (λ (m n) + (map f (list m n) (list n m)))) + x y) + '(4.0 6.0)))) + + (let ((a 7) + (b (tensor 13))) + (check-ρ-∇ (d+ a b) (tensor 20) (list 1.0 (tensor 1.0))) + (check-ρ-∇ (d* a b) (tensor 91) (list 13.0 (tensor 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13) (list 0.07692 (tensor -0.04142)))) + + (let ((a 7) + (b (tensor 13 15))) + (check-ρ-∇ (d+ a b) (tensor 20 22) (list 2.0 (tensor 1.0 1.0))) + (check-ρ-∇ (d* a b) (tensor 91 105) (list 28.0 (tensor 7.0 7.0))) + (check-ρ-∇ (d/ a b) (tensor 7/13 7/15) + (list 0.14358 (tensor -0.04142 -0.03111))))) diff --git a/lazy/ext-ops/test/test-B-comparators.rkt b/lazy/ext-ops/test/test-B-comparators.rkt new file mode 100644 index 0000000..9f3fdf5 --- /dev/null +++ b/lazy/ext-ops/test/test-B-comparators.rkt @@ -0,0 +1,12 @@ +(module+ test + (require rackunit) + (let ((a 2) + (b 3)) + (check-true (<-0-0 a b)) + (check-false (>-0-0 a b)) + (check-true (<=-0-0 a b)) + (check-false (>=-0-0 a b)) + (check-false (=-0-0 a b)) + (check-true (=-0-0 a a)) + (check-true (zero? 0)) + (check-false (zero? a)))) diff --git a/lazy/ext-ops/test/test-C-star-2-1.rkt b/lazy/ext-ops/test/test-C-star-2-1.rkt new file mode 100644 index 0000000..bbb2c8b --- /dev/null +++ b/lazy/ext-ops/test/test-C-star-2-1.rkt @@ -0,0 +1,24 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (tensor (tensor 36 52 70 90) (tensor 84 104 126 150))) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/lazy/ext-ops/test/test-D-sum.rkt b/lazy/ext-ops/test/test-D-sum.rkt new file mode 100644 index 0000000..e77ab0a --- /dev/null +++ b/lazy/ext-ops/test/test-D-sum.rkt @@ -0,0 +1,53 @@ +(module+ test + (require rackunit) + (require "C-star-2-1.ss") + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) + + (let ((a (tensor 3 4 5))) + (check-ρ-∇ (d-sum a) 12 + (list (tensor 1.0 1.0 1.0)))) + + (let ((a (tensor (tensor 3 4 5) + (tensor 6 7 8)))) + (check-dual-equal? (d-sum a) (tensor 12 21)) + (check-dual-equal? ((∇¹ (λ (b) (d-sum (d* b b)))) a) + (list (tensor (tensor 6.0 8.0 10.0) + (tensor 12.0 14.0 16.0))))) + + (define dot-product + (λ (a b) + (d-sum (d*-2-1 a b)))) + + (define sse + (λ (a b) + (d-sum (d-sqr (d- a b))))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor 2 3 4 5))) + + (check-ρ-∇ (dot-product a b) + (tensor 68 124) + (list (tensor (tensor 2.0 3.0 4.0 5.0) + (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + + (check-ρ-∇ (sse a b) + (tensor 4 100) + (list (tensor (tensor 2.0 2.0 2.0 2.0) + (tensor 10.0 10.0 10.0 10.0)) + (tensor -12.0 -12.0 -12.0 -12.0)))) + + (let ((a (tensor (tensor 3 4 5 6) + (tensor 7 8 9 10))) + (b (tensor (tensor 2 3 4 5) + (tensor 12 13 14 15)))) + + (check-ρ-∇ (dot-product a b) + (tensor (tensor 68 124) + (tensor 248 464)) + (list (tensor (tensor 14.0 16.0 18.0 20.0) + (tensor 14.0 16.0 18.0 20.0)) + (tensor (tensor 10.0 12.0 14.0 16.0) + (tensor 10.0 12.0 14.0 16.0)))))) diff --git a/lazy/ext-ops/test/test-E-argmax.rkt b/lazy/ext-ops/test/test-E-argmax.rkt new file mode 100644 index 0000000..f72819e --- /dev/null +++ b/lazy/ext-ops/test/test-E-argmax.rkt @@ -0,0 +1,17 @@ +(module+ test + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (d-argmax y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0)))) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-argmax y) (tensor 2.0 1.0 0.0 3.0) + (list + (tensor (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 0.0)))))) diff --git a/lazy/ext-ops/test/test-F-max.rkt b/lazy/ext-ops/test/test-F-max.rkt new file mode 100644 index 0000000..01ab1a5 --- /dev/null +++ b/lazy/ext-ops/test/test-F-max.rkt @@ -0,0 +1,10 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) + (tensor 0.0 1.0 0.0 0.0) + (tensor 1.0 0.0 0.0 0.0) + (tensor 0.0 0.0 0.0 1.0)))) + (check-ρ-∇ (d-max y) (tensor 1.0 1.0 1.0 1.0) + (list y)))) diff --git a/lazy/ext-ops/test/test-G-correlate.rkt b/lazy/ext-ops/test/test-G-correlate.rkt new file mode 100644 index 0000000..417723c --- /dev/null +++ b/lazy/ext-ops/test/test-G-correlate.rkt @@ -0,0 +1,118 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor ext2-∇ check-tensor-equal?)) + + ;; for testing b = 4 + ;; m = 3 + ;; d = 2 + + ;; signal length n = 6 + + ;; (1 2) (3 4) (5 6) (7 8) (9 10) (11 12) + ;; (1 2) (3 4) (5 6) + ;; (7 8) (9 10) (11 12) + ;; (13 14) (15 16) (17 18) + ;; (19 20) (21 22) (23 24) + + ;; Signal is (n d) + (define signal (tensor (tensor 1 2) + (tensor 3 4) + (tensor 5 6) + (tensor 7 8) + (tensor 9 10) + (tensor 11 12))) + + (define bank (tensor (tensor + (tensor 1 2) + (tensor 3 4) + (tensor 5 6)) + (tensor + (tensor 7 8) + (tensor 9 10) + (tensor 11 12)) + (tensor + (tensor 13 14) + (tensor 15 16) + (tensor 17 18)) + (tensor + (tensor 19 20) + (tensor 21 22) + (tensor 23 24)))) + + (define corr-ρ + (ext2-ρ (correlate-3-1-ρ 12 6 2) 3 1 correlate-shape)) + + (define corr-∇ + (ext2-∇ (correlate-3-1-∇ 12 6 2) 3 1 correlate-shape)) + + (check-tensor-equal? (corr-ρ bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let-values (((filter-∇ signal-∇) + (corr-∇ bank signal (tensor (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0))))) + (check-tensor-equal? filter-∇ + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-tensor-equal? signal-∇ + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0)))) + + (check-dual-equal? (d-correlate bank signal) + ;; Should be of size nb + (tensor (tensor 50.0 110.0 170.0 230.0) + (tensor 91.0 217.0 343.0 469.0) + (tensor 133.0 331.0 529.0 727.0) + (tensor 175.0 445.0 715.0 985.0) + (tensor 217.0 559.0 901.0 1243.0) + (tensor 110.0 362.0 614.0 866.0))) + + (let ((gs ((∇¹ d-correlate) bank signal))) + (check-dual-equal? (car gs) + (tensor + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)) + (tensor (tensor 25.0 30.0) + (tensor 36.0 42.0) + (tensor 35.0 40.0)))) + (check-dual-equal? (cadr gs) + ;; Should be of size nb + (tensor (tensor 88.0 96.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 144.0 156.0) + (tensor 104.0 112.0))))) diff --git a/lazy/no-duals-no-overrides.rkt b/lazy/no-duals-no-overrides.rkt new file mode 100644 index 0000000..6163435 --- /dev/null +++ b/lazy/no-duals-no-overrides.rkt @@ -0,0 +1,28 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/no-duals.rkt b/lazy/no-duals.rkt new file mode 100644 index 0000000..062be42 --- /dev/null +++ b/lazy/no-duals.rkt @@ -0,0 +1,28 @@ +#lang racket/base + +(module+ test + (require rackunit)) + +(require "tensors.rkt") +(require "ext-ops.rkt") + +(define scalar? number?) + +(provide + ;; From tensors + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + scalar? tensor? rank shape reshape trefs + + ;; From ext-ops + (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) + (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate)) + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/no-overrides.rkt b/lazy/no-overrides.rkt new file mode 100644 index 0000000..22fb210 --- /dev/null +++ b/lazy/no-overrides.rkt @@ -0,0 +1,37 @@ +#lang racket/base + +(require + (except-in "tensors.rkt" + rank shape reshape trefs tensor? tlen ref refr)) + +(require "autodiff.rkt") +(require "ext-ops.rkt") + +(provide + len ref refr + + tref tlen list->tensor tensor build-tensor + + ext1-ρ ext2-ρ ext1-∇ ext2-∇ + + dual dual? ρ κ ∇ ∇¹ + + ext1 ext2 prim1 prim2 + + scalar? tensor? rank shape reshape trefs + + trace-print check-dual-equal? check-ρ-∇ + + d+ d- d* d/ d-rectify + d-exp d-log d-expt d-sqrt + d-sum d-abs d*-2-1 d-argmax + d-max d-sum-cols d-correlate + + +-ρ --ρ *-ρ /-ρ rectify-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ + sum-ρ abs-ρ *-2-1-ρ argmax-ρ + max-ρ sum-cols-ρ correlate-ρ + + + =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 + =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt new file mode 100644 index 0000000..9461f5b --- /dev/null +++ b/lazy/tensors.rkt @@ -0,0 +1,19 @@ +#lang racket +(require "tensors/0-lazy.rkt") +(require "tensors/A-equality.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide tolerance tensor-equal? check-tensor-equal?) + +(provide len ref refr) +(provide tref tlen list->tensor tensor build-tensor trefs) + +(provide ext1-ρ ext2-ρ ext1-∇ ext2-∇) + +;; TODO: figure out why was this exported in flat-tensors +;;(provide flat? flat-shape flat-store flat-offset size-of strides) + +;; These will get overriden by duals +(provide tensor?) +(provide rank shape reshape size-of) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt new file mode 100644 index 0000000..a6cedad --- /dev/null +++ b/lazy/tensors/0-lazy.rkt @@ -0,0 +1,494 @@ +#lang racket +(require "../../flat-tensors/ext-impl.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + +;; tensor computations +(struct tcomp ()) +;; TODO: figure out if removing tcom-tensor is a good idea +(struct tcomp-tensor tcomp (t-shape t-flat) #:transparent) +(struct tcomp-list->tpromise-list tcomp (lst) #:transparent) +(struct tcomp-tp-map tcomp (f tp) #:transparent) +(struct tcomp-build-tpromise tcomp (s f) #:transparent) +(struct tcomp-tp-trefs tcomp (forced b) #:transparent) +(struct tcomp-ext2-∇ tcomp (b forcer) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp flat-f) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u flat-f) #:transparent) +(struct tcomp-ext1-ρ tcomp (tp flat-f) #:transparent) +(struct tcomp-reshape tcomp (s tp) #:transparent) + +(struct tpromise ((tensor #:mutable) shape) + #:guard + (λ (tensor shape name) + (unless (or (flat? tensor) (tcomp? tensor)) + (error 'make-tpromise + (string-append + "First argument must be either a" + " tcomp or a flat tensor. Got ~a") + tensor)) + (unless ((listof positive-integer?) shape) + (error 'make-tpromise + (string-append + "Second argument must be a list" + " of positive integers. Got ~a") + shape)) + (values tensor shape)) + #:transparent) + +(define scalar? number?) + +(define tensor + (λ args + (ensure-shape args) + (let ([inner-flat (tensor-inner-flat args)]) + (tpromise inner-flat (flat:shape inner-flat))))) + +(define tensor-inner-flat + (λ (args) + (cond + [(number? (car args)) (apply flat:tensor args)] + [else (merge-flats (map tp-force args))]))) + +(define ensure-shape + (λ (args) + (unless (and (not (null? args)) + (cond + ((number? (car args)) + (andmap number? (cdr args))) + ((tpromise? (car args)) + (let ((s (tp-shape (car args)))) + (andmap (λ (t) + (and (tpromise? t) + (equal? (tp-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Mismatched shapes: ~a~%" + args)))) + +(define ensure-flat + (λ (v) + (cond + ((scalar? v) (flat '() (vec v) 0)) + (else v)))) + +;(-> tpromise (U flat scalar)) +(define tp-force + (lambda (tp (print? #f)) + (when print? + (printf "~n####PP tensor: ") + (pretty-print tp)) + (let ([res + (match tp + [(tpromise t-tcomp _) + #:when (tcomp? t-tcomp) + (tcomp-force t-tcomp)] + [(tpromise t _) + #:when (or (flat? t) (scalar? t)) t] + + ;; NOTE: This case runs when we use tp-scalarize to turn + ;; the tensor to a scalar + [_ #f])]) + (cond + [res (set-tpromise-tensor! tp res) + res] + [else tp])))) + +(define tcomp-force + (λ (tc) + (match tc + #;[(tcomp-tensor t-shape t-flat) + (flat t-shape t-flat 0)] + [(tcomp-list->tpromise-list lst) + (flat:list->tensor + (map (λ (l) (tp-force l #f)) lst))] + [(tcomp-tp-map f tp) + (let* ([flat-vec (tp-force tp #f)] + [store (flat-store flat-vec)] + [shape (flat-shape flat-vec)] + [offset (flat-offset flat-vec)]) + (flat shape (vector-map f store) + offset))] + [(tcomp-build-tpromise s f) + (flat:build-tensor s f)] + [(tcomp-tp-trefs forced b) + (flat:trefs forced b)] + [(tcomp-ext2-∇ b forcer) + (let ([v (unbox b)]) + (cond + ((eqv? v 'uncalculated) + (forcer) + (unbox b)) + (else v)))] + [(tcomp-ext1-∇ tp zp flat-f) + (let ([t (tp-force tp #f)] + [z (tp-force zp #f)]) + (scalarize (flat-f (ensure-flat t) (ensure-flat z))))] + [(tcomp-ext2-ρ tp-t tp-u flat-f) + (let ([t (tp-force tp-t #f)] + [u (tp-force tp-u #f)]) + (scalarize (flat-f (ensure-flat t) (ensure-flat u))))] + [(tcomp-ext1-ρ tp flat-f) + (let ([t (tp-force tp #f)]) + (let ([res (scalarize (flat-f t))]) + res))] + [(tcomp-reshape s tp) + (let ([t (tp-force tp #f)]) + (flat s (flat-store t) (flat-offset t)))]))) + +(define tp-force-ref + (λ (tp i) + (flat:tref (tp-force tp) i))) + +(define bounded-idx*^ + (λ (shape idx*) + (match `(,shape ,idx*) + [`(,_ ()) #t] + [`(() ,_) #f] + [`((,sa . ,sd) (,ia . ,id*)) + (and (< ia sa) + (>= ia 0) + (bounded-idx*^ sd id*))]))) + +(define bounded-idx*? + (λ (tp idx*) + (bounded-idx*^ (tpromise-shape tp) idx*))) + +(define tp-tref + (lambda (tp i) + (cond + [(and (bounded-idx*? tp (list i)) + (flat? (tp-force-ref tp i))) + (tpromise (tp-force-ref tp i) + (flat-shape (tp-force-ref tp i)))] + [(bounded-idx*? tp (list i)) + (tp-force-ref tp i)] + [else (error 'exn:tp-tref + (string-append + "Index out of bounds. ~a " + "greater than or equals length ~a~%") + i + (tp-tlen tp))]))) + +(define tp-tlen + (λ (tp) + (car (tpromise-shape tp)))) + +(define tp-shape + (lambda (v) + (cond + [(tpromise? v) (tpromise-shape v)] + [else (flat:shape v)]))) + +(define list->tpromise + (λ (lst) + (cond + [(null? lst) + (error 'list->ltensor "No elements found")] + [else + (tpromise (tcomp-list->tpromise-list lst) + `(,(length lst) + . ,(tp-shape + (car lst))))]))) + +(define tp-tmap + (λ (f tp) + (struct-copy + tpromise tp + (tensor + (tcomp-tp-map f tp))))) + +(define build-tpromise + (λ (s f) + (tpromise (tcomp-build-tpromise s f) s))) + +(define tp-trefs + (λ (tp b) + (cond + [(ormap (λ (i) + (>= i + (car (tpromise-shape tp)))) + b) + (error 'tp-trefs + "An index was out of bounds")] + [else + (let ([forced (tp-force tp)]) + (tpromise (tcomp-tp-trefs forced b) + `(,(length b) + . ,(cdr (flat-shape forced)))))]))) + +(define tp-ext1-ρ + (λ (f + m + [shape-fn scalar-shape] + [context 'lazy-ext1]) + (let ((flat-f + (flat-ext1-ρ (flat-function-maker1 f m) + m shape-fn context))) + (λ (tp) + (cond + [(scalar? tp) (f tp)] + [(and (tpromise? tp) + (null? (tpromise-shape tp))) + (f (tp-force tp))] + [else + (tpromise + (tcomp-ext1-ρ tp flat-f) + (merge-shapes + (tpromise-shape tp) + m + (shape-fn + (min-shape m (tpromise-shape tp)))))]))))) + +(define tp-ext2-ρ + (λ (f + m + n + [shape-fn scalar-shape] + [context 'raw-ext2]) + (let ((flat-f + (flat-ext2-ρ (flat-function-maker2 f m n) + m n shape-fn #;context))) + (λ (tp-t tp-u) + (cond + ((and (number? tp-t) (number? tp-u)) + (f tp-t tp-u)) + [(and (tpromise? tp-t) (tpromise? tp-u) + (null? (tpromise-shape tp-t)) + (null? (tpromise-shape tp-u))) + (f (tp-force tp-t) (tp-force tp-u))] + [else + (let* ([s0 (tp-shape tp-t)] + [s1 (tp-shape tp-u)] + [sf0 (min-shape m s0)] + [sf1 (min-shape n s1)] + [sf-out (shape-fn sf0 sf1)]) + (tpromise + (tcomp-ext2-ρ tp-t tp-u flat-f) + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out))))]))))) + +(define scalarize + (λ (t) + (cond + ((null? (flat-shape t)) + (vref (flat-store t) 0)) + (else t)))) + +(define tp-scalarize + (λ (tp) + (cond + [(and (tpromise? tp) (null? (tpromise-shape tp))) + (scalarize (tp-force tp))] + [else tp]))) + +(define scalar-shape + (λ (s0 [s1 '()]) '())) + +(define left-shape + (λ (s0 s1) s0)) + +(define right-shape + (λ (s0 s1) s1)) + +(define flat-function-maker2 + (λ (f m n) + (cond + ((and (zero? m) (zero? n)) + (λ (v0 i0 stride0 v1 i1 stride1 + v-out i-out stride-out) + (vset! v-out i-out + (f (vref v0 i0) (vref v1 i1))))) + (else + f)))) + +(define flat-function-maker1 + (λ (f m) + (cond + ((zero? m) + (λ (v0 i0 stride0 v-out i-out stride-out) + (vset! v-out i-out (f (vref v0 i0))))) + (else f)))) + +(define tp-ext1-∇ + (λ (f + m + [shape-fn scalar-shape] + [context 'lazy-d-ext1]) + (let ((flat-f + (flat-ext1-∇ (flat-gradient-maker1 f m) + m shape-fn))) + (λ (tp zp) + (cond + ((number? tp) (f tp zp)) + (else + (tpromise + (tcomp-ext1-∇ tp zp flat-f) + (tpromise-shape tp)))))))) + +(define tp-d-ext2^ + (λ (fᵈ r0 r1 shape-fn [context 'lazy-flat-d-ext2]) + (λ (tp-t0 tp-t1 tp-z) + (let* ((s0 (tpromise-shape tp-t0)) + (sf0 (min-shape r0 s0)) + (stride0 (flat:size-of sf0)) + + (s1 (tpromise-shape tp-t1)) + (sf1 (min-shape r1 s1)) + (stride1 (flat:size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (flat:size-of sf-z)) + + (out0 (box 'uncalculated)) + (out1 (box 'uncalculated)) + (forcer + (λ () + (let* ((f0 (ensure-flat (tp-force tp-t0))) + (f1 (ensure-flat (tp-force tp-t1))) + (fz (ensure-flat (tp-force tp-z))) + + (v0 (flat-store f0)) + (v1 (flat-store f1)) + (vz (flat-store fz)) + + (off0 (flat-offset f0)) + (off1 (flat-offset f1)) + (offz (flat-offset fz))) + (ext2-shapes + s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (flat:size-of + s0) + 0.0 + context)) + (g1 (new-vec (flat:size-of + s1) + 0.0 + context))) + (for ([iz (in-range + 0 + size-z + stride-z)]) + (let-values (((i0 i1) + (idxs + strides + iz + off0 + off1))) + (fᵈ g0 g1 v0 i0 + stride0 + v1 i1 + stride1 + vz + (+ offz iz) + stride-z))) + (set-box! out0 + (flat s0 g0 0)) + (set-box! out1 + (flat s1 g1 0))))))))) + (values + (tpromise (tcomp-ext2-∇ out0 forcer) s0) + (tpromise (tcomp-ext2-∇ out1 forcer) s1)))))) + +(define ensure-tpromise + (λ (v) + (cond + ((scalar? v) (tpromise (ensure-flat v) '())) + (else v)))) + +(define tp-ext2-∇ + (λ (f + m + n + [shape-fn scalar-shape] + [context 'lazy-d-ext2]) + (let ((tp-f + (let ((f (tp-d-ext2^ + (flat-gradient-maker2 f m n) + m n shape-fn))) + (λ (tp-t tp-u tp-z) + (let-values (((tp-dt tp-du) + (f tp-t tp-u tp-z))) + (values (tp-scalarize tp-dt) + (tp-scalarize tp-du))))))) + (λ (tp-t tp-u tp-z) + (tp-f (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + (ensure-tpromise tp-z)))))) + +(define flat-gradient-maker2 + (λ (f m n) + (cond + ((and (zero? m) (zero? n)) + (λ (g0 + g1 + v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (let ((z (vref vz iz)) + (a (vref v0 i0)) + (b (vref v1 i1))) + (let-values (((da db) (f a b z))) + (vset! g0 i0 + (+ (vref g0 i0) da)) + (vset! g1 i1 + (+ (vref g1 i1) db)))))) + (else f)))) + +(define flat-gradient-maker1 + (λ (f0 m) + (cond + ((zero? m) + (λ (g0 v0 i0 stride0 vz iz stride-z) + (let ((z (vref vz iz)) + (a (vref v0 i0))) + (vset! g0 i0 (+ (vref g0 i0) + (f0 a z)))))) + (else f0)))) + +(define tp-rank + (λ (tp) + (flat:len (tp-shape tp)))) + +(define tp-reshape + (λ (s tp) + (cond + ((= (flat:size-of s) (flat:size-of (tpromise-shape tp))) + (tpromise (tcomp-reshape s tp) s)) + (else (error "Cannot reshape ~a to ~a~%" (tpromise-shape tp) s))))) + +(define tensor? + (lambda (tp) + (or (tpromise? tp) (scalar? tp)))) + +(include "test/test-0-lazy.rkt") + +(provide start-vector-manager vector-manager-report) + +(provide (rename-out + (flat:len len) + (flat:ref ref) + (flat:refr refr))) +(provide tensor + tp-force + (rename-out + (tp-scalarize scalarize) + (tp-tref tref) + (tp-tlen tlen) + (list->tpromise list->tensor) + (build-tpromise build-tensor) + (tp-trefs trefs))) + +(provide (rename-out + (tp-ext1-ρ ext1-ρ) + (tp-ext2-ρ ext2-ρ) + (tp-ext1-∇ ext1-∇) + (tp-ext2-∇ ext2-∇))) + +;; These will get overriden by duals +(provide tensor?) +(provide (rename-out + (tp-rank rank) + (tp-shape shape) + (tp-reshape reshape) + (flat:size-of size-of))) diff --git a/lazy/tensors/A-equality.rkt b/lazy/tensors/A-equality.rkt new file mode 100644 index 0000000..93851b1 --- /dev/null +++ b/lazy/tensors/A-equality.rkt @@ -0,0 +1,18 @@ +#lang racket + +(require "0-lazy.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + +(define tp-tensor-equal? + (λ (tp-actual tp-expected) + (flat:tensor-equal? (tp-force tp-actual) (tp-force tp-expected)))) + +(require rackunit) +(define-binary-check (tp-check-tensor-equal? tp-tensor-equal? actual expected)) + +(include "test/test-A-equality.rkt") + +(provide (rename-out + (flat:tolerance tolerance) + (tp-tensor-equal? tensor-equal?) + (tp-check-tensor-equal? check-tensor-equal?))) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt new file mode 100644 index 0000000..8ee0643 --- /dev/null +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -0,0 +1,227 @@ +(module+ test + (require rackunit) + + (define test-lt (tensor 1 2 3)) + (check-true (flat? (tpromise-tensor test-lt))) + (check-equal? (flat-store (tp-force test-lt)) (vector 1 2 3)) + (check-true (flat? (tpromise-tensor test-lt))) + (check-exn exn:fail? (λ () (tensor test-lt 4))) + (check-exn exn:fail? (λ () (tensor 4 test-lt))) + + (check-equal? (tp-tref test-lt 2) 3) + (check-exn exn:fail? (λ () (tp-tref test-lt 5))) + + (define test-nested-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) + (check-equal? (tp-tref (tp-tref test-nested-lt 0) 2) 3) + (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0)) 3) + (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2)) 3) + (check-exn exn:fail? (λ () (tensor test-nested-lt test-nested-lt test-lt))) + + (check-equal? (tp-tlen test-lt) 3) + (check-equal? (tp-tlen test-nested-lt) 2) + + (define test-lt-from-list (list->tpromise '(5 6 7 8))) + (check-equal? (flat-store (tp-force test-lt-from-list)) (vector 5 6 7 8)) + (define test-nested-lt-from-list + (list->tpromise `(,test-lt ,test-lt ,test-lt))) + (check-equal? (tpromise-shape test-nested-lt-from-list) '(3 3)) + + (check-true (bounded-idx*? test-nested-lt-from-list (list 0 1))) + (check-false (bounded-idx*? test-nested-lt-from-list (list 1 3))) + (check-false (bounded-idx*? test-nested-lt-from-list (list 1 1 0))) + + (define test-premap-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) + (define test-mapped-lt (tp-tmap add1 test-premap-lt)) + (check-true (flat? (tpromise-tensor test-premap-lt))) + (check-true (tcomp? (tpromise-tensor test-mapped-lt))) + (check-equal? (flat-store (tp-force test-mapped-lt)) (vector 2 3 4 5 6 7)) + (check-equal? (flat-shape (tp-force test-mapped-lt)) (flat-shape (tp-force test-premap-lt))) + (check-equal? (flat-offset (tp-force test-mapped-lt)) (flat-offset (tp-force test-premap-lt))) + (check-true (flat? (tpromise-tensor test-premap-lt))) + (check-true (flat? (tpromise-tensor test-mapped-lt))) + + (define test-build-shape '(4 3)) + (define test-built-tensor (build-tpromise test-build-shape + (λ (i) + (let ([row (car i)] + [column (cadr i)]) + (+ (* (sub1 (car test-build-shape)) + row) + column))))) + (check-equal? (tpromise-shape test-built-tensor) test-build-shape) + (check-true (tcomp? (tpromise-tensor test-built-tensor))) + + (define test-refs '(0 2)) + (define test-tp-trefs (tp-trefs test-built-tensor test-refs)) + (check-true (tcomp? (tpromise-tensor test-tp-trefs))) + (check-equal? (tpromise-shape test-tp-trefs) (flat-shape (tp-force test-tp-trefs))) + (check-equal? (flat-store (tp-force test-tp-trefs)) (vector 0 1 2 6 7 8)) + (check-exn exn:fail? (λ () (tp-trefs test-nested-lt '(0 4))) 3) + + (define sum-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (vset! out-v iₒ + (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) + (+ sum (vref in-v i)))))) + + (define sum (tp-ext1-ρ sum-f 1)) + (check-equal? (flat-store (tp-force (sum test-nested-lt))) (vec 6.0 15.0)) + + (define t0 + (build-tpromise '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) + (define *-ρ (tp-ext2-ρ * 0 0)) + (define t0sqr (*-ρ t0 t0)) + + (flat:check-tensor-equal? (tp-force t0sqr) + (flat:reshape + '(2 3 4) + (flat:tensor + 0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116))) + + (define *-2-1-f + (λ (v0 i0 s0 v1 i1 s1 vout iout sout) + (for ([j0 (in-range 0 s0)]) + (vset! vout (+ iout j0) + (* (vref v0 (+ i0 j0)) + (vref v1 (+ i1 (modulo j0 s1)))))))) + + (define t1 + (build-tpromise '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))) + + (define t2 + (build-tpromise '(6) + (λ (i) (* 3.0 (car i))))) + + (define *-2-1 + (tp-ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) + + (define r-1-2 + (*-2-1 t1 t2)) + + (check-equal? (tpromise-shape r-1-2) '(5 6)) + (flat:check-tensor-equal? (tp-force r-1-2) + (flat:reshape + '(5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0))) + + (define t3 + (build-tpromise '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + + (define t4 + (build-tpromise '(3 6) + (λ (i) + (match-define `(,x ,y) i) + (* 3.0 (+ (* x 6) y))))) + + (define r-3-4 + (*-2-1 t3 t4)) + + (check-equal? (tpromise-shape r-3-4) '(3 5 6)) + (flat:check-tensor-equal? (tp-force r-3-4) + (flat:reshape + '(3 5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + + (define r1-td (tensor 3.0 4.0 5.0)) + (define r2-td (tp-reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) + + (define +ᶠ +) + (define +ᵈ (λ (a b z) (values z z))) + + (define sqrᶠ (λ (a) (* a a))) + (define sqrᵈ + (λ (a z) (* z 2 a))) + + (define d-sqr (tp-ext1-∇ sqrᵈ 0 scalar-shape)) + + (define one-like + (λ (t) + (build-tpromise (tpromise-shape t) (λ (_) 1.0)))) + + (flat:check-tensor-equal? (tp-force (d-sqr r1-td (one-like r1-td))) + (flat:tensor 6.0 8.0 10.0)) + + (let ((gsqr (d-sqr r2-td (one-like r2-td)))) + (flat:check-tensor-equal? (tp-force gsqr) + (flat:reshape + '(2 3) + (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + + (define d+ (tp-ext2-∇ +ᵈ 0 0 scalar-shape)) + + (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) + (flat:check-tensor-equal? (tp-force da) + (flat:tensor 1.0 1.0 1.0)) + (flat:check-tensor-equal? (tp-force db) + (flat:tensor 1.0 1.0 1.0))) + + (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) + (flat:check-tensor-equal? (tp-force da) + (flat:tensor 2.0 2.0 2.0)) + (flat:check-tensor-equal? (tp-force db) + (flat:reshape + '(2 3) + (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + + (define *∇ (tp-ext2-∇ (λ (a b z) (values (* z b) (* z a))) + 0 + 0)) + + (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0 1.0)))) + (flat:check-tensor-equal? (tp-force gt) (tp-force (tensor 1.0 2.0 3.0))) + (flat:check-tensor-equal? (tp-force gu) (tp-force (tensor 2.0 3.0 4.0)))) + + (define sum-1-∇ + (λ (g t it st vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g i (vref vz iz))))) + + (define sum-∇ (tp-ext1-∇ sum-1-∇ 1 (λ (s) '()))) + + (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) + 1.0))) + (flat:check-tensor-equal? (tp-force gt) (tp-force (tensor 1.0 1.0 1.0)))) + + (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) + (tensor 2.0 3.0 4.0)) + (tensor 2.0 1.0)))) + (flat:check-tensor-equal? (tp-force gt) (tp-force (tensor (tensor 2.0 2.0 2.0) + (tensor 1.0 1.0 1.0)))))) diff --git a/lazy/tensors/test/test-A-equality.rkt b/lazy/tensors/test/test-A-equality.rkt new file mode 100644 index 0000000..9c87732 --- /dev/null +++ b/lazy/tensors/test/test-A-equality.rkt @@ -0,0 +1,49 @@ +(module+ test + (require rackunit) + + (define t0 + (reshape '(2 3 4) + (build-tensor + '(24) + (λ (i) + (* 2.0 (car i)))))) + + + (define t1 + (reshape '(2 3 4) + (build-tensor + '(24) + (λ (i) + (* 2.000001 (car i)))))) + + (define t2 + (reshape '(1 2 3 4) + (build-tensor + '(24) + (λ (i) + (* 2.000001 (car i)))))) + + (define t3 + (reshape '(2 2 3 4) + (build-tensor + '(48) + (λ (i) + (* (quotient (car i) 24) (car i)))))) + + (define t4 + (reshape '(2 2 3 4) + (build-tensor + '(48) + (λ (i) + (- (* 2.000001 (* (quotient (car i) 24) (car i))) 48.0))))) + + (check-true (tp-tensor-equal? t0 t1)) + + (check-false (tp-tensor-equal? t0 t2)) ;; elements are equal, but shapes are not + + (check-true (tp-tensor-equal? t0 (reshape '(2 3 4) + t2))) + + (tp-check-tensor-equal? t0 t1) + + (tp-check-tensor-equal? t0 (reshape '(2 3 4) t2))) diff --git a/malted/test/test-O-init.rkt b/malted/test/test-O-init.rkt index 5b142c0..342300e 100644 --- a/malted/test/test-O-init.rkt +++ b/malted/test/test-O-init.rkt @@ -1,16 +1,18 @@ (module+ test (require rackunit) + ;; TODO: Make this better. We musn't break abstraction boundaries + (require "../lazy/tensors/0-lazy.rkt") (define v (init-shape (list 1000 4))) - (define mean-v (abs (/ (sum (sum v)) 4000))) - (define variance-v (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v))) + (define mean-v (tp-force (abs (/ (sum (sum v)) 4000)))) + (define variance-v (tp-force (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v)))) (check-true (< mean-v 0.05)) (check-true (and (>= variance-v 0.4) (<= variance-v 0.6))) ;; Here variance will be 2/8 = 0.25 (define r (init-shape (list 1000 4 2))) - (define mean-r (abs (/ (sum (sum (sum r))) 8000))) - (define variance-r (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r))) + (define mean-r (tp-force (abs (/ (sum (sum (sum r))) 8000)))) + (define variance-r (tp-force (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r)))) (check-true (< mean-r 0.05)) (check-true (and (>= variance-r 0.22) From dde397018333125edd8f5196bab911fe4ae3eadf Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 20:42:56 -0400 Subject: [PATCH 53/83] [add-lazy]Make fixes after rebase into main --- flat-tensors/ext-impl.rkt | 4 + flat-tensors/tensors/D-extend.rkt | 2 + lazy.rkt | 10 +- lazy/autodiff.rkt | 2 +- lazy/autodiff/A-autodiff.rkt | 8 +- lazy/autodiff/E-print.rkt | 10 +- lazy/ext-ops.rkt | 5 +- lazy/ext-ops/A-scalar-ops.rkt | 12 +- lazy/ext-ops/B-comparators.rkt | 2 +- lazy/ext-ops/I-flatten.rkt | 31 +++ lazy/ext-ops/test/test-I-flatten.rkt | 13 + lazy/tensors.rkt | 3 +- lazy/tensors/0-lazy.rkt | 354 +++++++++++++++------------ lazy/tensors/test/test-0-lazy.rkt | 4 + 14 files changed, 277 insertions(+), 183 deletions(-) create mode 100644 lazy/ext-ops/I-flatten.rkt create mode 100644 lazy/ext-ops/test/test-I-flatten.rkt diff --git a/flat-tensors/ext-impl.rkt b/flat-tensors/ext-impl.rkt index b79efe3..2530b4f 100644 --- a/flat-tensors/ext-impl.rkt +++ b/flat-tensors/ext-impl.rkt @@ -10,6 +10,10 @@ flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ + functional->preallocated-1-ρ + functional->preallocated-1-∇ + functional->preallocated-2-ρ + functional->preallocated-2-∇ idxs)) (require (only-in "autodiff/E-print.rkt" make-printable-flat diff --git a/flat-tensors/tensors/D-extend.rkt b/flat-tensors/tensors/D-extend.rkt index 8d6c068..edca5fa 100644 --- a/flat-tensors/tensors/D-extend.rkt +++ b/flat-tensors/tensors/D-extend.rkt @@ -386,5 +386,7 @@ (include "test/test-D-extend.rkt") (provide ext1-ρ ext1-∇ ext2-ρ ext2-∇ expects-preallocated? + functional->preallocated-1-ρ functional->preallocated-1-∇ + functional->preallocated-2-ρ functional->preallocated-2-∇ merge-shapes min-shape ext2-shapes idxs flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ) diff --git a/lazy.rkt b/lazy.rkt index 5fdf862..84c64cc 100644 --- a/lazy.rkt +++ b/lazy.rkt @@ -14,7 +14,7 @@ ext1-ρ ext2-ρ ext1-∇ ext2-∇ - dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* ext1 ext2 prim1 prim2 @@ -25,17 +25,19 @@ (rename-out (d+ +) (d- -) (d* *) (d/ /) (d-rectify rectify) (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) - (d-max max) (d-sum-cols sum-cols) (d-correlate correlate)) + (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) + (d-flatten flatten)) +-ρ --ρ *-ρ /-ρ rectify-ρ exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ + flatten-ρ +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 - exp-0 log-0 abs-0 rectify-0 + exp-0 log-0 sqrt-0 abs-0 rectify-0 - sum-1 argmax-1 max-1 + flatten-2 sum-1 argmax-1 max-1 =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/autodiff.rkt b/lazy/autodiff.rkt index 66f8f3a..93d4384 100644 --- a/lazy/autodiff.rkt +++ b/lazy/autodiff.rkt @@ -6,7 +6,7 @@ (require "autodiff/D-test-helpers.rkt") (require "autodiff/E-print.rkt") -(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual*) +(provide dual dual? ρ κ ∇ ∇¹ scalar? trace-print dual* map*) (provide prim1 prim2 ext1 ext2) (provide (rename-out (d-rank rank) (d-shape shape) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index 51a54ef..e3f75f1 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -21,13 +21,15 @@ (define ρ (λ (d) (cond - ((dual? d) (vector-ref d 1)) + ((dual? d) (tp-force (vector-ref d 1)) + (vector-ref d 1)) (else d)))) (define κ (λ (d) (cond - ((dual? d) (vector-ref d 2)) + ((dual? d) (tp-force (vector-ref d 2)) + (vector-ref d 2)) (else end-of-chain)))) (define scalar? @@ -117,4 +119,4 @@ (provide dual dual? ρ κ ∇ ∇¹ dual* scalar? end-of-chain - trace-print) + trace-print map*) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt index 6c85c70..66494d4 100644 --- a/lazy/autodiff/E-print.rkt +++ b/lazy/autodiff/E-print.rkt @@ -10,19 +10,15 @@ (λ (y [max-length (max-tensor-print-length)]) (cond ((dual? y) (make-printable (ρ y))) - ((and (not (scalar? y)) (tensor? y)) - (make-printable-tp y max-length)) + ((tpromise? y) + (make-printable (tp-force y) max-length)) + ((flat? y) (make-printable-flat y max-length)) ((list? y) (map (λ (le) (make-printable le max-length)) y)) ((vector? y) (vector-map (λ (ve) (make-printable ve max-length)) y)) (else y)))) -(define make-printable-tp - (λ (y [max-length (max-tensor-print-length)]) - (make-printable-flat (tp-force y) max-length))) - - (include "test/test-E-print.rkt") (provide max-tensor-print-length diff --git a/lazy/ext-ops.rkt b/lazy/ext-ops.rkt index d3dff13..7bb112f 100644 --- a/lazy/ext-ops.rkt +++ b/lazy/ext-ops.rkt @@ -7,13 +7,14 @@ (require "ext-ops/E-argmax.rkt") (require "ext-ops/F-max.rkt") (require "ext-ops/G-correlate.rkt") +(require "ext-ops/I-flatten.rkt") (provide d+ d- d* d/ d-expt d-exp d-log d-abs d-rectify d-sqrt d-sqr +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 - exp-0 log-0 abs-0 rectify-0 + exp-0 log-0 sqrt-0 abs-0 rectify-0 +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ @@ -31,3 +32,5 @@ (provide max-1 d-max max-ρ) (provide correlate-ρ d-correlate) + +(provide flatten-2 d-flatten flatten-ρ) diff --git a/lazy/ext-ops/A-scalar-ops.rkt b/lazy/ext-ops/A-scalar-ops.rkt index 9c46948..06bd0f9 100644 --- a/lazy/ext-ops/A-scalar-ops.rkt +++ b/lazy/ext-ops/A-scalar-ops.rkt @@ -40,6 +40,11 @@ (λ (a z) (* z (/ 1 a))))) +(define sqrt-0 + (prim1 sqrt + (λ (x z) + (/ z (* 2 (sqrt x)))))) + (define abs-0-ρ (λ (x) (cond @@ -87,10 +92,7 @@ (define d-log (ext1 log-0 0)) (define d-abs (ext1 abs-0 0)) (define d-rectify (ext1 rectify-0 0)) - -(define d-sqrt - (λ (a) - (d-expt a 1/2))) +(define d-sqrt (ext1 sqrt-0 0)) (define d-sqr (λ (x) @@ -122,7 +124,7 @@ (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 - exp-0 log-0 abs-0 rectify-0 + exp-0 log-0 sqrt-0 abs-0 rectify-0 d+ d- d* d/ d-expt d-exp d-log d-abs diff --git a/lazy/ext-ops/B-comparators.rkt b/lazy/ext-ops/B-comparators.rkt index c42a2cf..7fcb184 100644 --- a/lazy/ext-ops/B-comparators.rkt +++ b/lazy/ext-ops/B-comparators.rkt @@ -24,7 +24,7 @@ (comparator >)) (define >=-0-0 - (comparator >)) + (comparator >=)) ;;---------------------------- ;; Tensorized comparators diff --git a/lazy/ext-ops/I-flatten.rkt b/lazy/ext-ops/I-flatten.rkt new file mode 100644 index 0000000..bf24773 --- /dev/null +++ b/lazy/ext-ops/I-flatten.rkt @@ -0,0 +1,31 @@ +#lang racket + +(require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) +(require (only-in "../autodiff.rkt" prim1 ext1)) + +(define flatten-2-ρ + (λ (t) + (reshape (flatten-shape (shape t)) t))) + +(define flatten-2-∇ + (λ (t z) + (reshape (shape t) z))) + +(define flatten-shape + (λ (s) + (let ((rows (ref s 0)) + (cols (ref s 1))) + (list (* rows cols))))) + +(define flatten-2 + (prim1 flatten-2-ρ flatten-2-∇ flatten-shape)) + +(define d-flatten + (ext1 flatten-2 2)) + +(define flatten-ρ + (ext1-ρ flatten-2-ρ 2)) + +(include "test/test-I-flatten.rkt") + +(provide flatten-2 d-flatten flatten-ρ) diff --git a/lazy/ext-ops/test/test-I-flatten.rkt b/lazy/ext-ops/test/test-I-flatten.rkt new file mode 100644 index 0000000..f7740cf --- /dev/null +++ b/lazy/ext-ops/test/test-I-flatten.rkt @@ -0,0 +1,13 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0)) + + (check-dual-equal? (flatten-2 r2-t1) r1-t1) + (check-ρ-∇ ((λ (t1 t2) (d* t1 (flatten-2 t2))) r1-t1 r2-t1) + (tensor 9.0 16.0 25.0 36.0) + (list (tensor 3.0 4.0 5.0 6.0) (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))))) diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt index 9461f5b..d2faac9 100644 --- a/lazy/tensors.rkt +++ b/lazy/tensors.rkt @@ -11,8 +11,7 @@ (provide ext1-ρ ext2-ρ ext1-∇ ext2-∇) -;; TODO: figure out why was this exported in flat-tensors -;;(provide flat? flat-shape flat-store flat-offset size-of strides) +(provide tp-force) ;; These will get overriden by duals (provide tensor?) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index a6cedad..bcbb5be 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -4,16 +4,18 @@ ;; tensor computations (struct tcomp ()) -;; TODO: figure out if removing tcom-tensor is a good idea -(struct tcomp-tensor tcomp (t-shape t-flat) #:transparent) (struct tcomp-list->tpromise-list tcomp (lst) #:transparent) (struct tcomp-tp-map tcomp (f tp) #:transparent) (struct tcomp-build-tpromise tcomp (s f) #:transparent) (struct tcomp-tp-trefs tcomp (forced b) #:transparent) (struct tcomp-ext2-∇ tcomp (b forcer) #:transparent) -(struct tcomp-ext1-∇ tcomp (tp zp flat-f) #:transparent) -(struct tcomp-ext2-ρ tcomp (tp-t tp-u flat-f) #:transparent) -(struct tcomp-ext1-ρ tcomp (tp flat-f) #:transparent) +(struct tcomp-ext1-∇-prealloc tcomp (tp zp f m shape-fn) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp f m shape-fn) #:transparent) +(struct tcomp-ext2-ρ-prealloc tcomp (tp-t tp-u f m n shape-fn) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u f m n shape-fn) #:transparent) +(struct tcomp-ext1-ρ-scalar tcomp (f tp) #:transparent) +(struct tcomp-ext1-ρ-prealloc tcomp (f m shape-fn tp) #:transparent) +(struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) (struct tpromise ((tensor #:mutable) shape) @@ -96,8 +98,6 @@ (define tcomp-force (λ (tc) (match tc - #;[(tcomp-tensor t-shape t-flat) - (flat t-shape t-flat 0)] [(tcomp-list->tpromise-list lst) (flat:list->tensor (map (λ (l) (tp-force l #f)) lst))] @@ -119,18 +119,45 @@ (forcer) (unbox b)) (else v)))] - [(tcomp-ext1-∇ tp zp flat-f) - (let ([t (tp-force tp #f)] - [z (tp-force zp #f)]) - (scalarize (flat-f (ensure-flat t) (ensure-flat z))))] - [(tcomp-ext2-ρ tp-t tp-u flat-f) - (let ([t (tp-force tp-t #f)] - [u (tp-force tp-u #f)]) - (scalarize (flat-f (ensure-flat t) (ensure-flat u))))] - [(tcomp-ext1-ρ tp flat-f) - (let ([t (tp-force tp #f)]) - (let ([res (scalarize (flat-f t))]) - res))] + [(tcomp-ext1-∇-prealloc tp zp f m shape-fn) + (scalarize + (flat-ext1-∇ f m shape-fn + (ensure-flat (tp-force tp)) + (ensure-flat (tp-force zp))))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (let ([t (ensure-flat (tp-force tp #f))] + [z (ensure-flat (tp-force zp #f))]) + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f m shape-fn t z))))] + [(tcomp-ext2-ρ-prealloc tp-t tp-u f m n shape-fn) + (scalarize + (flat-ext2-ρ f m n shape-fn + (ensure-flat (tp-force tp-t)) + (ensure-flat (tp-force tp-u))))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (let ([t (ensure-flat (tp-force tp-t #f))] + [u (ensure-flat (tp-force tp-u #f))]) + (let* ((t-shape (min-shape m (flat-shape t))) + (u-shape (min-shape n (flat-shape u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f m n shape-fn t u))))] + [(tcomp-ext1-ρ-scalar f tp) + (f (tp-force tp))] + [(tcomp-ext1-ρ-prealloc f m shape-fn tp) + (scalarize (flat-ext1-ρ f m shape-fn (ensure-flat (tp-force tp))))] + [(tcomp-ext1-ρ f m shape-fn tp) + (let ([t (ensure-flat (tp-force tp #f))]) + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f m shape-fn t))))] [(tcomp-reshape s tp) (let ([t (tp-force tp #f)]) (flat s (flat-store t) (flat-offset t)))]))) @@ -217,55 +244,65 @@ . ,(cdr (flat-shape forced)))))]))) (define tp-ext1-ρ - (λ (f - m - [shape-fn scalar-shape] - [context 'lazy-ext1]) - (let ((flat-f - (flat-ext1-ρ (flat-function-maker1 f m) - m shape-fn context))) - (λ (tp) - (cond - [(scalar? tp) (f tp)] - [(and (tpromise? tp) - (null? (tpromise-shape tp))) - (f (tp-force tp))] - [else - (tpromise - (tcomp-ext1-ρ tp flat-f) - (merge-shapes - (tpromise-shape tp) - m - (shape-fn - (min-shape m (tpromise-shape tp)))))]))))) + (λ (f m [shape-fn scalar-shape]) + (λ (tp) + (cond + [(scalar? tp) (f tp)] + [(and (tpromise? tp) + (null? (tpromise-shape tp))) + (tpromise + (tcomp-ext1-ρ-scalar f tp) + '())] + [(flat:expects-preallocated? f) + (tpromise + (tcomp-ext1-ρ-prealloc f m shape-fn tp) + (merge-shapes + (tp-shape tp) + m + (shape-fn + (min-shape m (tp-shape tp)))))] + [else + (tpromise + (tcomp-ext1-ρ f m shape-fn tp) + (merge-shapes + (tp-shape tp) + m + (shape-fn + (min-shape m (tp-shape tp)))))])))) (define tp-ext2-ρ - (λ (f - m - n - [shape-fn scalar-shape] - [context 'raw-ext2]) - (let ((flat-f - (flat-ext2-ρ (flat-function-maker2 f m n) - m n shape-fn #;context))) - (λ (tp-t tp-u) - (cond - ((and (number? tp-t) (number? tp-u)) - (f tp-t tp-u)) - [(and (tpromise? tp-t) (tpromise? tp-u) - (null? (tpromise-shape tp-t)) - (null? (tpromise-shape tp-u))) - (f (tp-force tp-t) (tp-force tp-u))] - [else - (let* ([s0 (tp-shape tp-t)] - [s1 (tp-shape tp-u)] - [sf0 (min-shape m s0)] - [sf1 (min-shape n s1)] - [sf-out (shape-fn sf0 sf1)]) - (tpromise - (tcomp-ext2-ρ tp-t tp-u flat-f) - (ext2-shapes s0 s1 m n sf-out - (λ (s-out . _) s-out))))]))))) + (λ (f m n [shape-fn scalar-shape]) + (λ (tp-t tp-u) + (cond + ((and (number? tp-t) (number? tp-u)) + (f tp-t tp-u)) + [(and (tpromise? tp-t) (tpromise? tp-u) + (null? (tpromise-shape tp-t)) + (null? (tpromise-shape tp-u))) + ;; TODO: move this to a tcomp + (f (tp-force tp-t) (tp-force tp-u))] + [(flat:expects-preallocated? f) + (let* ([s0 (tp-shape tp-t)] + [s1 (tp-shape tp-u)] + [sf0 (min-shape m s0)] + [sf1 (min-shape n s1)] + [sf-out (shape-fn sf0 sf1)]) + (tpromise + (tcomp-ext2-ρ-prealloc tp-t tp-u + f m n shape-fn) + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out))))] + [else + (let* ([s0 (tp-shape tp-t)] + [s1 (tp-shape tp-u)] + [sf0 (min-shape m s0)] + [sf1 (min-shape n s1)] + [sf-out (shape-fn sf0 sf1)]) + (tpromise + (tcomp-ext2-ρ (ensure-tpromise tp-t) (ensure-tpromise tp-u) + f m n shape-fn) + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out))))])))) (define scalarize (λ (t) @@ -310,85 +347,80 @@ (else f)))) (define tp-ext1-∇ - (λ (f - m - [shape-fn scalar-shape] - [context 'lazy-d-ext1]) - (let ((flat-f - (flat-ext1-∇ (flat-gradient-maker1 f m) - m shape-fn))) - (λ (tp zp) - (cond - ((number? tp) (f tp zp)) - (else - (tpromise - (tcomp-ext1-∇ tp zp flat-f) - (tpromise-shape tp)))))))) + (λ (f m [shape-fn scalar-shape]) + (λ (tp zp) + (cond + ((number? tp) (f tp zp)) + ((flat:expects-preallocated? f) + (tpromise + (tcomp-ext1-∇-prealloc tp zp f m shape-fn) + (tp-shape tp))) + (else + (tpromise + (tcomp-ext1-∇ tp zp f m shape-fn) + (tp-shape tp))))))) (define tp-d-ext2^ - (λ (fᵈ r0 r1 shape-fn [context 'lazy-flat-d-ext2]) - (λ (tp-t0 tp-t1 tp-z) - (let* ((s0 (tpromise-shape tp-t0)) - (sf0 (min-shape r0 s0)) - (stride0 (flat:size-of sf0)) - - (s1 (tpromise-shape tp-t1)) - (sf1 (min-shape r1 s1)) - (stride1 (flat:size-of sf1)) - - (sf-z (shape-fn sf0 sf1)) - (stride-z (flat:size-of sf-z)) - - (out0 (box 'uncalculated)) - (out1 (box 'uncalculated)) - (forcer - (λ () - (let* ((f0 (ensure-flat (tp-force tp-t0))) - (f1 (ensure-flat (tp-force tp-t1))) - (fz (ensure-flat (tp-force tp-z))) - - (v0 (flat-store f0)) - (v1 (flat-store f1)) - (vz (flat-store fz)) - - (off0 (flat-offset f0)) - (off1 (flat-offset f1)) - (offz (flat-offset fz))) - (ext2-shapes - s0 s1 r0 r1 sf-z - (λ (sz size-z q0 q1 strides) - (let ((g0 (new-vec (flat:size-of - s0) - 0.0 - context)) - (g1 (new-vec (flat:size-of - s1) - 0.0 - context))) - (for ([iz (in-range - 0 - size-z - stride-z)]) - (let-values (((i0 i1) - (idxs - strides - iz - off0 - off1))) - (fᵈ g0 g1 v0 i0 - stride0 - v1 i1 - stride1 - vz - (+ offz iz) - stride-z))) - (set-box! out0 - (flat s0 g0 0)) - (set-box! out1 - (flat s1 g1 0))))))))) - (values - (tpromise (tcomp-ext2-∇ out0 forcer) s0) - (tpromise (tcomp-ext2-∇ out1 forcer) s1)))))) + (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) + (let* ((s0 (tp-shape tp-t0)) + (sf0 (min-shape r0 s0)) + (stride0 (flat:size-of sf0)) + + (s1 (tp-shape tp-t1)) + (sf1 (min-shape r1 s1)) + (stride1 (flat:size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (flat:size-of sf-z)) + + (out0 (box 'uncalculated)) + (out1 (box 'uncalculated)) + (forcer + (λ () + (let* ((f0 (ensure-flat (tp-force tp-t0))) + (f1 (ensure-flat (tp-force tp-t1))) + (fz (ensure-flat (tp-force tp-z))) + + (v0 (flat-store f0)) + (v1 (flat-store f1)) + (vz (flat-store fz)) + + (off0 (flat-offset f0)) + (off1 (flat-offset f1)) + (offz (flat-offset fz))) + (ext2-shapes + s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (flat:size-of + s0) + 0.0)) + (g1 (new-vec (flat:size-of + s1) + 0.0))) + (for ([iz (in-range + 0 + size-z + stride-z)]) + (let-values (((i0 i1) + (idxs + strides + iz + off0 + off1))) + (fᵈ g0 g1 v0 i0 + stride0 + v1 i1 + stride1 + vz + (+ offz iz) + stride-z))) + (set-box! out0 + (flat s0 g0 0)) + (set-box! out1 + (flat s1 g1 0))))))))) + (values + (tpromise (tcomp-ext2-∇ out0 forcer) s0) + (tpromise (tcomp-ext2-∇ out1 forcer) s1))))) (define ensure-tpromise (λ (v) @@ -397,24 +429,27 @@ (else v)))) (define tp-ext2-∇ - (λ (f - m - n - [shape-fn scalar-shape] - [context 'lazy-d-ext2]) + (λ (f m n [shape-fn scalar-shape]) (let ((tp-f - (let ((f (tp-d-ext2^ - (flat-gradient-maker2 f m n) - m n shape-fn))) - (λ (tp-t tp-u tp-z) - (let-values (((tp-dt tp-du) - (f tp-t tp-u tp-z))) - (values (tp-scalarize tp-dt) - (tp-scalarize tp-du))))))) + (λ (f tp-t tp-u tp-z) + (let-values (((tp-dt tp-du) + (tp-d-ext2^ f m n shape-fn + tp-t tp-u tp-z))) + (values (tp-scalarize tp-dt) + (tp-scalarize tp-du)))))) (λ (tp-t tp-u tp-z) - (tp-f (ensure-tpromise tp-t) - (ensure-tpromise tp-u) - (ensure-tpromise tp-z)))))) + (cond + ((flat:expects-preallocated? f) + (tp-f f tp-t tp-u tp-z)) + [else (let* ((t-shape (min-shape m (tp-shape tp-t))) + (u-shape (min-shape n (tp-shape tp-u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ + f t-shape u-shape out-shape))) + (tp-f flat-f + (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + (ensure-tpromise tp-z)))]))))) (define flat-gradient-maker2 (λ (f m n) @@ -471,6 +506,7 @@ (flat:refr refr))) (provide tensor tp-force + tpromise? (rename-out (tp-scalarize scalarize) (tp-tref tref) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index 8ee0643..cb7ccb9 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -67,6 +67,10 @@ (define sum (tp-ext1-ρ sum-f 1)) (check-equal? (flat-store (tp-force (sum test-nested-lt))) (vec 6.0 15.0)) + (define id-f (lambda (v) v)) + (define id-ρ (tp-ext1-ρ id-f 1 (λ (s) s))) + (check-equal? (flat-store (tp-force (id-ρ test-nested-lt))) (vec 1 2 3 4 5 6)) + (define t0 (build-tpromise '(2 3 4) (λ (i) From 1a78b7c923c67bb3fd72878cc422e9e4ca24fd60 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 1 Apr 2023 11:30:59 -0400 Subject: [PATCH 54/83] [add-lazy]Add TODOs --- lazy/tensors/0-lazy.rkt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index bcbb5be..354e5ca 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -2,6 +2,9 @@ (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +;; TODO: Ensure that any calls to tp-force are occurring only as a recursive +;; call in this file. The only place we tp-force is in ρ, κ and print functions. + ;; tensor computations (struct tcomp ()) (struct tcomp-list->tpromise-list tcomp (lst) #:transparent) @@ -48,6 +51,7 @@ (λ (args) (cond [(number? (car args)) (apply flat:tensor args)] + ;; TODO: Add the below map as another tcomp [else (merge-flats (map tp-force args))]))) (define ensure-shape @@ -360,6 +364,9 @@ (tcomp-ext1-∇ tp zp f m shape-fn) (tp-shape tp))))))) +;; TODO: make sure that all functions being stored in tcomp structs should be +;; prims. Other functions should be inlined into tp-tcomp by passing all +;; parameters to these functions through the tcomp struct. (define tp-d-ext2^ (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) (let* ((s0 (tp-shape tp-t0)) From c544fa8cba51859f87f96a5f92ce0918994a71f5 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 1 Apr 2023 12:57:32 -0400 Subject: [PATCH 55/83] [add-lazy]Fix iris and test cases --- lazy/autodiff/A-autodiff.rkt | 6 +++--- lazy/autodiff/D-test-helpers.rkt | 9 +++++---- lazy/tensors.rkt | 2 +- lazy/tensors/0-lazy.rkt | 2 +- 4 files changed, 10 insertions(+), 9 deletions(-) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index e3f75f1..51c3f38 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -21,15 +21,15 @@ (define ρ (λ (d) (cond + ((tensor? d) (tp-force d) d) ((dual? d) (tp-force (vector-ref d 1)) - (vector-ref d 1)) + (scalarize (vector-ref d 1))) (else d)))) (define κ (λ (d) (cond - ((dual? d) (tp-force (vector-ref d 2)) - (vector-ref d 2)) + ((dual? d) (tp-force (vector-ref d 2))) (else end-of-chain)))) (define scalar? diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index c017981..f504e0e 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -13,16 +13,17 @@ (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) (let* ((y (tp-force (apply fn args))) - (g (tp-force (apply (∇¹ fn) args)))) + (g (tp-force (apply (∇¹ fn) args))) + (ans-ρ (ρ ans))) (cond - ((and (equal-wt? ans (ρ y)) + ((and (equal-wt? ans-ρ (ρ y)) (equal-wt? grads (ρ g))) (void)) - ((equal-wt? ans (ρ y)) + ((equal-wt? ans-ρ (ρ y)) (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" (ρ g) grads))) (else (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" - (ρ y) ans)))))) + (ρ y) ans-ρ)))))) (define-syntax check-ρ-∇ (syntax-rules () diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt index d2faac9..9b06ad2 100644 --- a/lazy/tensors.rkt +++ b/lazy/tensors.rkt @@ -11,7 +11,7 @@ (provide ext1-ρ ext2-ρ ext1-∇ ext2-∇) -(provide tp-force) +(provide tp-force scalarize) ;; These will get overriden by duals (provide tensor?) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 354e5ca..4aa532e 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -311,7 +311,7 @@ (define scalarize (λ (t) (cond - ((null? (flat-shape t)) + ((and (flat? t) (null? (flat-shape t))) (vref (flat-store t) 0)) (else t)))) From 78194236e412e6d06359236a35fc1977c6be415d Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 8 Apr 2023 12:49:38 -0400 Subject: [PATCH 56/83] [add-lazy]Move functions around --- lazy/autodiff/B-prims.rkt | 7 +- lazy/tensors/0-lazy.rkt | 189 ++++++++++++++++-------------- lazy/tensors/test/test-0-lazy.rkt | 2 +- 3 files changed, 107 insertions(+), 91 deletions(-) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 879f28f..09ee621 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -30,7 +30,7 @@ (dual (ρ-fn ra) (λ (d z σ) (let ((ga (∇-fn ra z))) - ((κ da) da ga σ))))))) + ((κ da) da ga #;(tp-force ga) σ))))))) (define prim2 (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) @@ -49,8 +49,9 @@ (dual (ρ-fn ra rb) (λ (d z σ) (let-values (((ga gb) (∇-fn ra rb z))) - (let ((σ-hat ((κ da) da ga σ))) - ((κ db) db gb σ-hat)))))))) + ;; TODO: remove the tp-force here + (let ((σ-hat ((κ da) da (tp-force ga) σ))) + ((κ db) db (tp-force gb) σ-hat)))))))) ;;---------------------------- ;; Dualized tensor op creators diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 4aa532e..3cb97ca 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -2,8 +2,6 @@ (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) -;; TODO: Ensure that any calls to tp-force are occurring only as a recursive -;; call in this file. The only place we tp-force is in ρ, κ and print functions. ;; tensor computations (struct tcomp ()) @@ -11,15 +9,18 @@ (struct tcomp-tp-map tcomp (f tp) #:transparent) (struct tcomp-build-tpromise tcomp (s f) #:transparent) (struct tcomp-tp-trefs tcomp (forced b) #:transparent) -(struct tcomp-ext2-∇ tcomp (b forcer) #:transparent) +(struct tcomp-ext2-∇ tcomp (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + #:transparent) (struct tcomp-ext1-∇-prealloc tcomp (tp zp f m shape-fn) #:transparent) (struct tcomp-ext1-∇ tcomp (tp zp f m shape-fn) #:transparent) +(struct tcomp-ext2-ρ-scalar tcomp (f tp-t tp-u) #:transparent) (struct tcomp-ext2-ρ-prealloc tcomp (tp-t tp-u f m n shape-fn) #:transparent) (struct tcomp-ext2-ρ tcomp (tp-t tp-u f m n shape-fn) #:transparent) (struct tcomp-ext1-ρ-scalar tcomp (f tp) #:transparent) (struct tcomp-ext1-ρ-prealloc tcomp (f m shape-fn tp) #:transparent) (struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) +(struct tcomp-tensor tcomp (args) #:transparent) (struct tpromise ((tensor #:mutable) shape) #:guard @@ -44,15 +45,22 @@ (define tensor (λ args (ensure-shape args) - (let ([inner-flat (tensor-inner-flat args)]) - (tpromise inner-flat (flat:shape inner-flat))))) + (let ((inner-flat (tensor-inner-flat args)) + ) + (cond + ((flat? inner-flat) + (tpromise inner-flat (flat-shape inner-flat))) + (else + (let* ((inner-shape (tpromise-shape (car args))) + (outer (length args)) + (new-shape (cons outer inner-shape))) + (tpromise inner-flat new-shape))))))) (define tensor-inner-flat (λ (args) (cond [(number? (car args)) (apply flat:tensor args)] - ;; TODO: Add the below map as another tcomp - [else (merge-flats (map tp-force args))]))) + [else (tcomp-tensor args)]))) (define ensure-shape (λ (args) @@ -116,11 +124,12 @@ (flat:build-tensor s f)] [(tcomp-tp-trefs forced b) (flat:trefs forced b)] - [(tcomp-ext2-∇ b forcer) - (let ([v (unbox b)]) + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let* ([b (if (zero? i) out0 out1)] + [v (unbox b)]) (cond ((eqv? v 'uncalculated) - (forcer) + (ext2-∇-forcer fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1) (unbox b)) (else v)))] [(tcomp-ext1-∇-prealloc tp zp f m shape-fn) @@ -136,6 +145,8 @@ (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) (scalarize (flat-ext1-∇ flat-f m shape-fn t z))))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (f (tp-force tp-t) (tp-force tp-t))] [(tcomp-ext2-ρ-prealloc tp-t tp-u f m n shape-fn) (scalarize (flat-ext2-ρ f m n shape-fn @@ -164,8 +175,12 @@ (flat-ext1-ρ flat-f m shape-fn t))))] [(tcomp-reshape s tp) (let ([t (tp-force tp #f)]) - (flat s (flat-store t) (flat-offset t)))]))) + (flat s (flat-store t) (flat-offset t)))] + [(tcomp-tensor args) + + (merge-flats (map tp-force args))]))) +;; TODO: This can also be made lazy (define tp-force-ref (λ (tp i) (flat:tref (tp-force tp) i))) @@ -283,8 +298,7 @@ [(and (tpromise? tp-t) (tpromise? tp-u) (null? (tpromise-shape tp-t)) (null? (tpromise-shape tp-u))) - ;; TODO: move this to a tcomp - (f (tp-force tp-t) (tp-force tp-u))] + (tpromise (tcomp-ext2-ρ-scalar f tp-t tp-u) '())] [(flat:expects-preallocated? f) (let* ([s0 (tp-shape tp-t)] [s1 (tp-shape tp-u)] @@ -364,11 +378,28 @@ (tcomp-ext1-∇ tp zp f m shape-fn) (tp-shape tp))))))) -;; TODO: make sure that all functions being stored in tcomp structs should be -;; prims. Other functions should be inlined into tp-tcomp by passing all -;; parameters to these functions through the tcomp struct. -(define tp-d-ext2^ - (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) +(define tp-ext2-∇ + (λ (f m n [shape-fn scalar-shape]) + (let ((tp-f + (λ (f tp-t tp-u tp-z) + (tp-d-ext2^ f m n shape-fn + tp-t tp-u tp-z)))) + (λ (tp-t tp-u tp-z) + (cond + ((flat:expects-preallocated? f) + (tp-f f tp-t tp-u tp-z)) + [else (let* ((t-shape (min-shape m (tp-shape tp-t))) + (u-shape (min-shape n (tp-shape tp-u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ + f t-shape u-shape out-shape))) + (tp-f flat-f + (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + (ensure-tpromise tp-z)))]))))) + +(define ext2-∇-forcer + (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1) (let* ((s0 (tp-shape tp-t0)) (sf0 (min-shape r0 s0)) (stride0 (flat:size-of sf0)) @@ -380,54 +411,60 @@ (sf-z (shape-fn sf0 sf1)) (stride-z (flat:size-of sf-z)) - (out0 (box 'uncalculated)) - (out1 (box 'uncalculated)) - (forcer - (λ () - (let* ((f0 (ensure-flat (tp-force tp-t0))) - (f1 (ensure-flat (tp-force tp-t1))) - (fz (ensure-flat (tp-force tp-z))) - - (v0 (flat-store f0)) - (v1 (flat-store f1)) - (vz (flat-store fz)) - - (off0 (flat-offset f0)) - (off1 (flat-offset f1)) - (offz (flat-offset fz))) - (ext2-shapes - s0 s1 r0 r1 sf-z - (λ (sz size-z q0 q1 strides) - (let ((g0 (new-vec (flat:size-of - s0) - 0.0)) - (g1 (new-vec (flat:size-of - s1) - 0.0))) - (for ([iz (in-range - 0 - size-z - stride-z)]) - (let-values (((i0 i1) - (idxs - strides - iz - off0 - off1))) - (fᵈ g0 g1 v0 i0 - stride0 - v1 i1 - stride1 - vz - (+ offz iz) - stride-z))) - (set-box! out0 - (flat s0 g0 0)) - (set-box! out1 - (flat s1 g1 0))))))))) + (f0 (ensure-flat (tp-force tp-t0))) + (f1 (ensure-flat (tp-force tp-t1))) + (fz (ensure-flat (tp-force tp-z))) + + (v0 (flat-store f0)) + (v1 (flat-store f1)) + (vz (flat-store fz)) + + (off0 (flat-offset f0)) + (off1 (flat-offset f1)) + (offz (flat-offset fz))) + (ext2-shapes + s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (flat:size-of + s0) + 0.0)) + (g1 (new-vec (flat:size-of + s1) + 0.0))) + (for ([iz (in-range + 0 + size-z + stride-z)]) + (let-values (((i0 i1) + (idxs + strides + iz + off0 + off1))) + (fᵈ g0 g1 v0 i0 + stride0 + v1 i1 + stride1 + vz + (+ offz iz) + stride-z))) + (set-box! out0 + (scalarize (flat s0 g0 0))) + (set-box! out1 + (scalarize (flat s1 g1 0))))))))) + +;; TODO: make sure that all functions being stored in tcomp structs should be +;; prims. Other functions should be inlined into tp-tcomp by passing all +;; parameters to these functions through the tcomp struct. +(define tp-d-ext2^ + (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) + (let* ((out0 (box 'uncalculated)) + (out1 (box 'uncalculated))) (values - (tpromise (tcomp-ext2-∇ out0 forcer) s0) - (tpromise (tcomp-ext2-∇ out1 forcer) s1))))) + (tpromise (tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 0) + (tp-shape tp-t0)) + (tpromise (tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 1) + (tp-shape tp-t1)))))) (define ensure-tpromise (λ (v) @@ -435,28 +472,6 @@ ((scalar? v) (tpromise (ensure-flat v) '())) (else v)))) -(define tp-ext2-∇ - (λ (f m n [shape-fn scalar-shape]) - (let ((tp-f - (λ (f tp-t tp-u tp-z) - (let-values (((tp-dt tp-du) - (tp-d-ext2^ f m n shape-fn - tp-t tp-u tp-z))) - (values (tp-scalarize tp-dt) - (tp-scalarize tp-du)))))) - (λ (tp-t tp-u tp-z) - (cond - ((flat:expects-preallocated? f) - (tp-f f tp-t tp-u tp-z)) - [else (let* ((t-shape (min-shape m (tp-shape tp-t))) - (u-shape (min-shape n (tp-shape tp-u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ - f t-shape u-shape out-shape))) - (tp-f flat-f - (ensure-tpromise tp-t) - (ensure-tpromise tp-u) - (ensure-tpromise tp-z)))]))))) (define flat-gradient-maker2 (λ (f m n) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index cb7ccb9..fd1f6b6 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -32,7 +32,7 @@ (define test-premap-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) (define test-mapped-lt (tp-tmap add1 test-premap-lt)) - (check-true (flat? (tpromise-tensor test-premap-lt))) + (check-false (flat? (tpromise-tensor test-premap-lt))) (check-true (tcomp? (tpromise-tensor test-mapped-lt))) (check-equal? (flat-store (tp-force test-mapped-lt)) (vector 2 3 4 5 6 7)) (check-equal? (flat-shape (tp-force test-mapped-lt)) (flat-shape (tp-force test-premap-lt))) From cc0dffaaeff76657f56db7bb2dc69c3aeb4a4519 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 22 Jul 2023 10:00:46 -0400 Subject: [PATCH 57/83] [add-lazy]Fix test cases and add concat to lazy --- lazy.rkt | 7 +- lazy/autodiff/A-autodiff.rkt | 8 +- lazy/autodiff/B-prims.rkt | 8 +- lazy/ext-ops.rkt | 4 + lazy/ext-ops/K-concat.rkt | 75 ++++++++++++++ lazy/ext-ops/test/test-K-concat.rkt | 126 +++++++++++++++++++++++ lazy/tensors/0-lazy.rkt | 153 +++++++++++++++------------- lazy/tensors/test/test-0-lazy.rkt | 14 +-- set-impl.rkt | 3 +- 9 files changed, 302 insertions(+), 96 deletions(-) create mode 100644 lazy/ext-ops/K-concat.rkt create mode 100644 lazy/ext-ops/test/test-K-concat.rkt diff --git a/lazy.rkt b/lazy.rkt index 84c64cc..5861675 100644 --- a/lazy.rkt +++ b/lazy.rkt @@ -26,18 +26,19 @@ (d-exp exp) (d-log log) (d-expt expt) (d-sqrt sqrt) (d-sqr sqr) (d-sum sum) (d-abs abs) (d*-2-1 *-2-1) (d-argmax argmax) (d-max max) (d-sum-cols sum-cols) (d-correlate correlate) - (d-flatten flatten)) + (d-flatten flatten) + (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ - flatten-ρ + flatten-ρ concat-ρ +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 exp-0 log-0 sqrt-0 abs-0 rectify-0 - flatten-2 sum-1 argmax-1 max-1 + sum-1 argmax-1 max-1 flatten-2 concat-1-1 =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index 51c3f38..48fc839 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -21,15 +21,13 @@ (define ρ (λ (d) (cond - ((tensor? d) (tp-force d) d) - ((dual? d) (tp-force (vector-ref d 1)) - (scalarize (vector-ref d 1))) + ((dual? d) (scalarize (vector-ref d 1))) (else d)))) (define κ (λ (d) (cond - ((dual? d) (tp-force (vector-ref d 2))) + ((dual? d) (vector-ref d 2)) (else end-of-chain)))) (define scalar? @@ -75,7 +73,7 @@ (λ (y wrt) (let ((σ (∇σ y (hasheq)))) (map* (λ (d) - (hash-ref σ d 0.0)) + (tp-force (hash-ref σ d 0.0))) wrt)))) (define ∇σ diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 09ee621..4957cb6 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -29,8 +29,9 @@ (let ((ra (ρ da))) (dual (ρ-fn ra) (λ (d z σ) + ;; TODO: need force*-1 here while calling ∇-fn (let ((ga (∇-fn ra z))) - ((κ da) da ga #;(tp-force ga) σ))))))) + ((κ da) da #;ga (tp-force ga) σ))))))) (define prim2 (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) @@ -48,8 +49,9 @@ (rb (ρ db))) (dual (ρ-fn ra rb) (λ (d z σ) - (let-values (((ga gb) (∇-fn ra rb z))) - ;; TODO: remove the tp-force here + (let-values (((ga gb) (∇-fn ra rb z) + ;; TODO: define a force*-2 for this + #;(force*-2 z (lambda (z) (∇-fn ra rb z))))) (let ((σ-hat ((κ da) da (tp-force ga) σ))) ((κ db) db (tp-force gb) σ-hat)))))))) diff --git a/lazy/ext-ops.rkt b/lazy/ext-ops.rkt index 7bb112f..67709fa 100644 --- a/lazy/ext-ops.rkt +++ b/lazy/ext-ops.rkt @@ -8,6 +8,7 @@ (require "ext-ops/F-max.rkt") (require "ext-ops/G-correlate.rkt") (require "ext-ops/I-flatten.rkt") +(require "ext-ops/K-concat.rkt") (provide d+ d- d* d/ d-expt d-exp d-log d-abs @@ -34,3 +35,6 @@ (provide correlate-ρ d-correlate) (provide flatten-2 d-flatten flatten-ρ) + +(provide concat-1-1 d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/lazy/ext-ops/K-concat.rkt b/lazy/ext-ops/K-concat.rkt new file mode 100644 index 0000000..1f3ae9a --- /dev/null +++ b/lazy/ext-ops/K-concat.rkt @@ -0,0 +1,75 @@ +#lang racket + +(require (rename-in (only-in "../tensors.rkt" ext2-ρ tref tlen shape len ref) + (shape shape-ρ))) +(require (only-in "../autodiff.rkt" prim2 ext2 shape)) + +(define concat-shape + (λ (st su) + (cons (+ (ref st 0) (ref su 0)) + (cdr st)))) + +(define concat-base-ρ + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + (for ([i (in-range 0 stride-out)]) + (cond + ((< i stride0) + (vector-set! v-out (+ i-out i) (vector-ref v0 (+ i0 i)))) + (else + (vector-set! v-out (+ i-out i) (vector-ref v1 (+ i1 (- i stride0))))))))) + +(define concat-base-∇ + (λ (g0 g1 v0 i0 stride0 + v1 i1 stride1 + vz iz stride-z) + (for ([i (in-range 0 stride-z)]) + (cond + ((< i stride0) + (vector-set! g0 (+ i0 i) + (+ (vector-ref g0 (+ i0 i)) + (vector-ref vz (+ iz i))))) + (else + (vector-set! g1 (+ i1 (- i stride0)) + (+ (vector-ref g1 (+ i1 (- i stride0))) + (vector-ref vz (+ iz i))))))))) + +(define concat-base + (prim2 concat-base-ρ concat-base-∇ concat-shape)) + +(define d-concat-n + (λ (n) + (λ (t u) + (let ((st (shape t)) + (su (shape u))) + (ensure-compatible-shapes n st su) + ((ext2 concat-base n n) t u))))) + +(define concat-n-ρ + (λ (n) + (λ (t u) + (let ((st (shape-ρ t)) + (su (shape-ρ u))) + (ensure-compatible-shapes n st su) + ((ext2 concat-base-ρ n n concat-shape) t u))))) + +(define ensure-compatible-shapes + (λ (n st su) + (let ((rt (len st)) + (ru (len su))) + ;; The shape of the tensor of rank r at rank n-1 + ;; is given by (drop st (+ 1 (- r n))) + (when (not (equal? (drop st (+ 1 (- rt n))) (drop su (+ 1 (- ru n))))) + (error 'concat "Incompatible concat shapes: ~a and ~a at last ~a dimensions" + st su n))))) + +(define d-concat (d-concat-n 1)) +(define concat-ρ (concat-n-ρ 1)) +(define concat-1-1 concat-base) + +(include "test/test-K-concat.rkt") + +(provide concat-1-1 + d-concat concat-ρ + d-concat-n concat-n-ρ) diff --git a/lazy/ext-ops/test/test-K-concat.rkt b/lazy/ext-ops/test/test-K-concat.rkt new file mode 100644 index 0000000..b427ced --- /dev/null +++ b/lazy/ext-ops/test/test-K-concat.rkt @@ -0,0 +1,126 @@ +(module+ test + (require rackunit) + (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../autodiff.rkt" check-ρ-∇ check-dual-equal?)) + (require (only-in "A-scalar-ops.rkt" d*)) + + (define r2-t1 (tensor (tensor 3.0 4.0) (tensor 5.0 6.0))) + (define r1-t2 (tensor 5.0 6.0 7.0)) + (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + + (check-dual-equal? + (d-concat r2-t1 r1-t2) + (tensor (tensor 3.0 4.0 5.0 6.0 7.0) + (tensor 5.0 6.0 5.0 6.0 7.0))) + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (d-concat t1 t2))) r2-t1 r1-t2 r1-t1) + (tensor (tensor 9.0 16.0 25.0 36.0 49.0) + (tensor 15.0 24.0 25.0 36.0 49.0)) + (list (tensor (tensor 3.0 4.0) (tensor 3.0 4.0)) + (tensor 10.0 12.0 14.0) + (tensor 8.0 10.0 10.0 12.0 14.0))) + (define r3-t1 + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0)))) + + + (define r2-t2 + (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0))) + + (define r1-t3 + (tensor 0.5 0.5)) + + (define concat-2 (d-concat-n 2)) + + (check-dual-equal? + (concat-2 r3-t1 r2-t2) + (tensor (tensor (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 9.0 10.0) + (tensor 11.0 12.0) + (tensor 13.0 14.0) + (tensor 15.0 16.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)) + + (tensor (tensor 17.0 18.0) + (tensor 19.0 20.0) + (tensor 21.0 22.0) + (tensor 23.0 24.0) + (tensor 1.0 2.0) + (tensor 3.0 4.0) + (tensor 5.0 6.0) + (tensor 7.0 8.0)))) + + + (check-ρ-∇ ((λ (t1 t2 t3) (d* t3 (concat-2 t1 t2))) r3-t1 r2-t2 r1-t3) + (tensor (tensor (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 4.5 5.0) + (tensor 5.5 6.0) + (tensor 6.5 7.0) + (tensor 7.5 8.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0)) + + (tensor (tensor 8.5 9.0) + (tensor 9.5 10.0) + (tensor 10.5 11.0) + (tensor 11.5 12.0) + (tensor 0.5 1.0) + (tensor 1.5 2.0) + (tensor 2.5 3.0) + (tensor 3.5 4.0))) + (list + (tensor (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5)) + (tensor (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5) + (tensor 0.5 0.5))) + + (tensor (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5) + (tensor 1.5 1.5)) + + (tensor 192.0 216.0)))) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 3cb97ca..bcfe28e 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -5,10 +5,31 @@ ;; tensor computations (struct tcomp ()) +#; +(: lst (U (Listof tpromise) (Listof Number))) (struct tcomp-list->tpromise-list tcomp (lst) #:transparent) -(struct tcomp-tp-map tcomp (f tp) #:transparent) +#; +(: s (Listof Natural)) ;; non-empty +#; +(: f (-> (Listof Natural) Number)) (struct tcomp-build-tpromise tcomp (s f) #:transparent) -(struct tcomp-tp-trefs tcomp (forced b) #:transparent) +#; +(: tp tpromise) +#; +(: i Natural) +(struct tcomp-tp-tref tcomp (tp i) #:transparent) +#; +(: tp tpromise) +#; +(: i (Listof Natural)) +(struct tcomp-tp-trefs tcomp (tp b) #:transparent) +;;TODO: Use functional->preallocated-* to use non-mutated/functional types for +;;the ext base functions +#; +(: fᵈ (U (-> Number Number (Values Number Number)) + (-> (Vector Number) Natural (Listof Natural) + (Vector Number) Natural (Listof Natural) + (Vector Number) Natural (Listof Natural)))) (struct tcomp-ext2-∇ tcomp (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) #:transparent) (struct tcomp-ext1-∇-prealloc tcomp (tp zp f m shape-fn) #:transparent) @@ -20,6 +41,8 @@ (struct tcomp-ext1-ρ-prealloc tcomp (f m shape-fn tp) #:transparent) (struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) +#; +(: args (U (Listof tpromise) (Listof Number))) (struct tcomp-tensor tcomp (args) #:transparent) (struct tpromise ((tensor #:mutable) shape) @@ -40,13 +63,21 @@ (values tensor shape)) #:transparent) +#; +(: scalar? (-> Any Boolean)) (define scalar? number?) +#; +(: tensor (case-> (-> tpromise * tpromise) + (-> Number * tpromise))) (define tensor (λ args - (ensure-shape args) - (let ((inner-flat (tensor-inner-flat args)) - ) + (unless (ensure-shape args) + (error 'tensor + "Mismatched shapes: ~a~%" + args)) + + (let ((inner-flat (tensor-inner-flat args))) (cond ((flat? inner-flat) (tpromise inner-flat (flat-shape inner-flat))) @@ -55,37 +86,41 @@ (outer (length args)) (new-shape (cons outer inner-shape))) (tpromise inner-flat new-shape))))))) - +#; +(: tensor-inner-flat (-> (U (Listof tpromise) (Listof Number)) + (U flat tcomp))) (define tensor-inner-flat (λ (args) (cond [(number? (car args)) (apply flat:tensor args)] [else (tcomp-tensor args)]))) +#; +(: ensure-shape (-> (U (Listof tpromise) (Listof Number)) Boolean)) (define ensure-shape (λ (args) - (unless (and (not (null? args)) - (cond - ((number? (car args)) - (andmap number? (cdr args))) - ((tpromise? (car args)) - (let ((s (tp-shape (car args)))) - (andmap (λ (t) - (and (tpromise? t) - (equal? (tp-shape t) s))) - (cdr args)))) - (else #f))) - (error 'tensor - "Mismatched shapes: ~a~%" - args)))) - + (and (not (null? args)) + (cond + ((number? (car args)) + (andmap number? (cdr args))) + ((tpromise? (car args)) + (let ((s (tp-shape (car args)))) + (andmap (λ (t) + (and (tpromise? t) + (equal? (tp-shape t) s))) + (cdr args)))) + (else #f))))) + +#; +(: ensure-flat (-> (U flat Number) flat)) (define ensure-flat (λ (v) (cond ((scalar? v) (flat '() (vec v) 0)) (else v)))) -;(-> tpromise (U flat scalar)) +#; +(: tp-force (-> tpromise (U flat Number))) (define tp-force (lambda (tp (print? #f)) (when print? @@ -107,23 +142,20 @@ res] [else tp])))) +#; +(: tcomp-force (-> tcomp (U flat Number))) (define tcomp-force (λ (tc) (match tc [(tcomp-list->tpromise-list lst) (flat:list->tensor (map (λ (l) (tp-force l #f)) lst))] - [(tcomp-tp-map f tp) - (let* ([flat-vec (tp-force tp #f)] - [store (flat-store flat-vec)] - [shape (flat-shape flat-vec)] - [offset (flat-offset flat-vec)]) - (flat shape (vector-map f store) - offset))] [(tcomp-build-tpromise s f) (flat:build-tensor s f)] - [(tcomp-tp-trefs forced b) - (flat:trefs forced b)] + [(tcomp-tp-tref tp i) + (flat:tref (tp-force tp) i)] + [(tcomp-tp-trefs tp b) + (flat:trefs (tp-force tp) b)] [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let* ([b (if (zero? i) out0 out1)] [v (unbox b)]) @@ -133,7 +165,7 @@ (unbox b)) (else v)))] [(tcomp-ext1-∇-prealloc tp zp f m shape-fn) - (scalarize + (tp-scalarize (flat-ext1-∇ f m shape-fn (ensure-flat (tp-force tp)) (ensure-flat (tp-force zp))))] @@ -144,11 +176,11 @@ (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (scalarize (flat-ext1-∇ flat-f m shape-fn t z))))] + (tp-scalarize (flat-ext1-∇ flat-f m shape-fn t z))))] [(tcomp-ext2-ρ-scalar f tp-t tp-u) (f (tp-force tp-t) (tp-force tp-t))] [(tcomp-ext2-ρ-prealloc tp-t tp-u f m n shape-fn) - (scalarize + (tp-scalarize (flat-ext2-ρ f m n shape-fn (ensure-flat (tp-force tp-t)) (ensure-flat (tp-force tp-u))))] @@ -159,19 +191,19 @@ (u-shape (min-shape n (flat-shape u))) (out-shape (shape-fn t-shape u-shape)) (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) - (scalarize + (tp-scalarize (flat-ext2-ρ flat-f m n shape-fn t u))))] [(tcomp-ext1-ρ-scalar f tp) (f (tp-force tp))] [(tcomp-ext1-ρ-prealloc f m shape-fn tp) - (scalarize (flat-ext1-ρ f m shape-fn (ensure-flat (tp-force tp))))] + (tp-scalarize (flat-ext1-ρ f m shape-fn (ensure-flat (tp-force tp))))] [(tcomp-ext1-ρ f m shape-fn tp) (let ([t (ensure-flat (tp-force tp #f))]) (let* ((in-shape (flat-shape t)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) - (scalarize + (tp-scalarize (flat-ext1-ρ flat-f m shape-fn t))))] [(tcomp-reshape s tp) (let ([t (tp-force tp #f)]) @@ -180,11 +212,6 @@ (merge-flats (map tp-force args))]))) -;; TODO: This can also be made lazy -(define tp-force-ref - (λ (tp i) - (flat:tref (tp-force tp) i))) - (define bounded-idx*^ (λ (shape idx*) (match `(,shape ,idx*) @@ -202,12 +229,9 @@ (define tp-tref (lambda (tp i) (cond - [(and (bounded-idx*? tp (list i)) - (flat? (tp-force-ref tp i))) - (tpromise (tp-force-ref tp i) - (flat-shape (tp-force-ref tp i)))] [(bounded-idx*? tp (list i)) - (tp-force-ref tp i)] + (tpromise (tcomp-tp-tref tp i) + (cdr (tpromise-shape tp)))] [else (error 'exn:tp-tref (string-append "Index out of bounds. ~a " @@ -236,13 +260,6 @@ . ,(tp-shape (car lst))))]))) -(define tp-tmap - (λ (f tp) - (struct-copy - tpromise tp - (tensor - (tcomp-tp-map f tp))))) - (define build-tpromise (λ (s f) (tpromise (tcomp-build-tpromise s f) s))) @@ -257,10 +274,9 @@ (error 'tp-trefs "An index was out of bounds")] [else - (let ([forced (tp-force tp)]) - (tpromise (tcomp-tp-trefs forced b) - `(,(length b) - . ,(cdr (flat-shape forced)))))]))) + (tpromise (tcomp-tp-trefs tp b) + `(,(length b) + . ,(cdr (tpromise-shape tp))))]))) (define tp-ext1-ρ (λ (f m [shape-fn scalar-shape]) @@ -322,18 +338,13 @@ (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) -(define scalarize - (λ (t) - (cond - ((and (flat? t) (null? (flat-shape t))) - (vref (flat-store t) 0)) - (else t)))) - (define tp-scalarize (λ (tp) (cond [(and (tpromise? tp) (null? (tpromise-shape tp))) - (scalarize (tp-force tp))] + (tp-scalarize (tp-force tp))] + [(and (flat? tp) (null? (flat-shape tp))) + (vref (flat-store tp) 0)] [else tp]))) (define scalar-shape @@ -449,13 +460,11 @@ (+ offz iz) stride-z))) (set-box! out0 - (scalarize (flat s0 g0 0))) + (tp-scalarize (flat s0 g0 0))) (set-box! out1 - (scalarize (flat s1 g1 0))))))))) + (tp-scalarize (flat s1 g1 0))))))))) -;; TODO: make sure that all functions being stored in tcomp structs should be -;; prims. Other functions should be inlined into tp-tcomp by passing all -;; parameters to these functions through the tcomp struct. +;; TODO: Create a lazy-apply-2 that does this more generally (define tp-d-ext2^ (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) (let* ((out0 (box 'uncalculated)) @@ -516,7 +525,7 @@ (define tensor? (lambda (tp) - (or (tpromise? tp) (scalar? tp)))) + (or (tpromise? tp) (flat? tp) (scalar? tp)))) (include "test/test-0-lazy.rkt") diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index fd1f6b6..c395b0a 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -8,11 +8,11 @@ (check-exn exn:fail? (λ () (tensor test-lt 4))) (check-exn exn:fail? (λ () (tensor 4 test-lt))) - (check-equal? (tp-tref test-lt 2) 3) + (check-equal? (tp-force (tp-tref test-lt 2)) 3) (check-exn exn:fail? (λ () (tp-tref test-lt 5))) (define test-nested-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) - (check-equal? (tp-tref (tp-tref test-nested-lt 0) 2) 3) + (check-equal? (tp-force (tp-tref (tp-tref test-nested-lt 0) 2)) 3) (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0)) 3) (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2)) 3) (check-exn exn:fail? (λ () (tensor test-nested-lt test-nested-lt test-lt))) @@ -30,16 +30,6 @@ (check-false (bounded-idx*? test-nested-lt-from-list (list 1 3))) (check-false (bounded-idx*? test-nested-lt-from-list (list 1 1 0))) - (define test-premap-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) - (define test-mapped-lt (tp-tmap add1 test-premap-lt)) - (check-false (flat? (tpromise-tensor test-premap-lt))) - (check-true (tcomp? (tpromise-tensor test-mapped-lt))) - (check-equal? (flat-store (tp-force test-mapped-lt)) (vector 2 3 4 5 6 7)) - (check-equal? (flat-shape (tp-force test-mapped-lt)) (flat-shape (tp-force test-premap-lt))) - (check-equal? (flat-offset (tp-force test-mapped-lt)) (flat-offset (tp-force test-premap-lt))) - (check-true (flat? (tpromise-tensor test-premap-lt))) - (check-true (flat? (tpromise-tensor test-mapped-lt))) - (define test-build-shape '(4 3)) (define test-built-tensor (build-tpromise test-build-shape (λ (i) diff --git a/set-impl.rkt b/set-impl.rkt index ded4d92..469e7b5 100644 --- a/set-impl.rkt +++ b/set-impl.rkt @@ -11,7 +11,8 @@ nested-tensors flat-tensors uniform-tensors - accelerated-tensors))) + accelerated-tensors + lazy))) (error "Unknown implementation: ~a~%" impl)) (setup #:collections (list (list "malt")) #:clean? #t) (write-implementation-to-config-file impl) From 32f14d520feea74ddee509d715d5691d54251a75 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 22 Jul 2023 13:12:41 -0400 Subject: [PATCH 58/83] [add-lazy]Implement force* and fix provides --- lazy.rkt | 2 +- lazy/autodiff.rkt | 1 + lazy/autodiff/B-prims.rkt | 14 +++++++++++--- lazy/autodiff/C-dualized-tensor-ops.rkt | 6 +++++- lazy/no-duals-no-overrides.rkt | 5 +++-- lazy/no-duals.rkt | 7 ++++--- lazy/no-overrides.rkt | 14 ++++++++++---- lazy/tensors.rkt | 2 ++ lazy/tensors/0-lazy.rkt | 13 ++++++++++++- 9 files changed, 49 insertions(+), 15 deletions(-) diff --git a/lazy.rkt b/lazy.rkt index 5861675..59b0e9a 100644 --- a/lazy.rkt +++ b/lazy.rkt @@ -2,7 +2,7 @@ (require (except-in "lazy/tensors.rkt" - rank shape reshape trefs tensor? tlen ref refr)) + rank shape reshape tref trefs tensor? tlen ref refr)) (require "lazy/autodiff.rkt") (require "lazy/ext-ops.rkt") diff --git a/lazy/autodiff.rkt b/lazy/autodiff.rkt index 93d4384..ed89acd 100644 --- a/lazy/autodiff.rkt +++ b/lazy/autodiff.rkt @@ -11,6 +11,7 @@ (provide (rename-out (d-rank rank) (d-shape shape) (d-reshape reshape) + (d-tref tref) (d-trefs trefs) (d-tensor? tensor?) (d-tlen tlen) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 4957cb6..b7fc2fc 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -30,8 +30,11 @@ (dual (ρ-fn ra) (λ (d z σ) ;; TODO: need force*-1 here while calling ∇-fn - (let ((ga (∇-fn ra z))) - ((κ da) da #;ga (tp-force ga) σ))))))) + #;(let ((ga (∇-fn ra z))) + ((κ da) da #;ga (tp-force ga) σ)) + (force*1 (∇-fn ra z) + (λ (ga) + ((κ da) da ga σ)))))))) (define prim2 (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) @@ -49,11 +52,16 @@ (rb (ρ db))) (dual (ρ-fn ra rb) (λ (d z σ) + #; (let-values (((ga gb) (∇-fn ra rb z) ;; TODO: define a force*-2 for this #;(force*-2 z (lambda (z) (∇-fn ra rb z))))) (let ((σ-hat ((κ da) da (tp-force ga) σ))) - ((κ db) db (tp-force gb) σ-hat)))))))) + ((κ db) db (tp-force gb) σ-hat))) + (force*2 (λ () (∇-fn ra rb z)) + (λ (ga gb) + (let ((σ-hat ((κ da) da ga σ))) + ((κ db) db gb σ-hat))))))))) ;;---------------------------- ;; Dualized tensor op creators diff --git a/lazy/autodiff/C-dualized-tensor-ops.rkt b/lazy/autodiff/C-dualized-tensor-ops.rkt index 982d20b..988bc80 100644 --- a/lazy/autodiff/C-dualized-tensor-ops.rkt +++ b/lazy/autodiff/C-dualized-tensor-ops.rkt @@ -24,6 +24,10 @@ (κ t))) (else (reshape s t))))) +(define d-tref + (λ (t i) + (tref (ρ t) i))) + (define d-trefs (λ (t b) (trefs (ρ t) b))) @@ -44,4 +48,4 @@ (λ (l i) (refr l (ρ i)))) -(provide d-rank d-shape d-reshape d-trefs d-tensor? d-tlen d-ref d-refr) +(provide d-rank d-shape d-reshape d-trefs d-tensor? d-tlen d-ref d-refr d-tref) diff --git a/lazy/no-duals-no-overrides.rkt b/lazy/no-duals-no-overrides.rkt index 6163435..ac07a7a 100644 --- a/lazy/no-duals-no-overrides.rkt +++ b/lazy/no-duals-no-overrides.rkt @@ -14,15 +14,16 @@ tref tlen list->tensor tensor build-tensor - ext1-ρ ext2-ρ ext1-∇ ext2-∇ + ext1-ρ ext2-ρ scalar? tensor? rank shape reshape trefs ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ + concat-ρ concat-n-ρ flatten-ρ =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/no-duals.rkt b/lazy/no-duals.rkt index 062be42..927c8c7 100644 --- a/lazy/no-duals.rkt +++ b/lazy/no-duals.rkt @@ -14,15 +14,16 @@ tref tlen list->tensor tensor build-tensor - ext1-ρ ext2-ρ ext1-∇ ext2-∇ + ext1-ρ ext2-ρ scalar? tensor? rank shape reshape trefs ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) - (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate)) + (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) + (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/no-overrides.rkt b/lazy/no-overrides.rkt index 22fb210..35dcbdd 100644 --- a/lazy/no-overrides.rkt +++ b/lazy/no-overrides.rkt @@ -2,7 +2,7 @@ (require (except-in "tensors.rkt" - rank shape reshape trefs tensor? tlen ref refr)) + rank shape reshape trefs tref tensor? tlen ref refr)) (require "autodiff.rkt") (require "ext-ops.rkt") @@ -21,17 +21,23 @@ scalar? tensor? rank shape reshape trefs trace-print check-dual-equal? check-ρ-∇ + make-printable + + +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 + flatten-2 concat-1-1 d+ d- d* d/ d-rectify - d-exp d-log d-expt d-sqrt + d-exp d-log d-expt d-sqrt d-sqr d-sum d-abs d*-2-1 d-argmax d-max d-sum-cols d-correlate + d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ - + flatten-ρ concat-ρ concat-n-ρ =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt index 9b06ad2..e86ab07 100644 --- a/lazy/tensors.rkt +++ b/lazy/tensors.rkt @@ -16,3 +16,5 @@ ;; These will get overriden by duals (provide tensor?) (provide rank shape reshape size-of) + +(provide force*1 force*2) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index bcfe28e..13e5b8d 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -24,7 +24,7 @@ (: i (Listof Natural)) (struct tcomp-tp-trefs tcomp (tp b) #:transparent) ;;TODO: Use functional->preallocated-* to use non-mutated/functional types for -;;the ext base functions +;; the ext base functions #; (: fᵈ (U (-> Number Number (Values Number Number)) (-> (Vector Number) Natural (Listof Natural) @@ -527,6 +527,15 @@ (lambda (tp) (or (tpromise? tp) (flat? tp) (scalar? tp)))) +(define force*1 + (λ (t f) + (f (tp-force t)))) + +(define force*2 + (λ (ts f) + (let-values (((t1 t2) (ts))) + (f (tp-force t1) (tp-force t2))))) + (include "test/test-0-lazy.rkt") (provide start-vector-manager vector-manager-report) @@ -559,3 +568,5 @@ (tp-shape shape) (tp-reshape reshape) (flat:size-of size-of))) + +(provide force*1 force*2) From 28a107494e10005717bf9b02eae99188e25fb142 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 26 Jul 2023 21:32:26 -0400 Subject: [PATCH 59/83] [add-lazy]Fix test cases --- lazy/tensors/test/test-0-lazy.rkt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index c395b0a..09da45a 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -13,8 +13,8 @@ (define test-nested-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) (check-equal? (tp-force (tp-tref (tp-tref test-nested-lt 0) 2)) 3) - (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0)) 3) - (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2)) 3) + (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0))) + (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2))) (check-exn exn:fail? (λ () (tensor test-nested-lt test-nested-lt test-lt))) (check-equal? (tp-tlen test-lt) 3) @@ -46,7 +46,7 @@ (check-true (tcomp? (tpromise-tensor test-tp-trefs))) (check-equal? (tpromise-shape test-tp-trefs) (flat-shape (tp-force test-tp-trefs))) (check-equal? (flat-store (tp-force test-tp-trefs)) (vector 0 1 2 6 7 8)) - (check-exn exn:fail? (λ () (tp-trefs test-nested-lt '(0 4))) 3) + (check-exn exn:fail? (λ () (tp-trefs test-nested-lt '(0 4)))) (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) From e751c9fbed3a9c85608297ec3d0d5169b70a5d9c Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 21:12:47 -0400 Subject: [PATCH 60/83] [add-lazy]Bugfixes and refactoring --- lazy/autodiff/A-autodiff.rkt | 2 +- lazy/tensors/0-lazy.rkt | 54 +++++++++++++++++++++++++----------- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index 48fc839..b4210f9 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -22,7 +22,7 @@ (λ (d) (cond ((dual? d) (scalarize (vector-ref d 1))) - (else d)))) + (else (scalarize d))))) (define κ (λ (d) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 13e5b8d..d3e78b1 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -7,24 +7,22 @@ (struct tcomp ()) #; (: lst (U (Listof tpromise) (Listof Number))) -(struct tcomp-list->tpromise-list tcomp (lst) #:transparent) +(struct tcomp-list->tensor tcomp (lst) #:transparent) #; (: s (Listof Natural)) ;; non-empty #; (: f (-> (Listof Natural) Number)) -(struct tcomp-build-tpromise tcomp (s f) #:transparent) +(struct tcomp-build-tensor tcomp (s f) #:transparent) #; (: tp tpromise) #; (: i Natural) -(struct tcomp-tp-tref tcomp (tp i) #:transparent) +(struct tcomp-tref tcomp (tp i) #:transparent) #; (: tp tpromise) #; (: i (Listof Natural)) -(struct tcomp-tp-trefs tcomp (tp b) #:transparent) -;;TODO: Use functional->preallocated-* to use non-mutated/functional types for -;; the ext base functions +(struct tcomp-trefs tcomp (tp b) #:transparent) #; (: fᵈ (U (-> Number Number (Values Number Number)) (-> (Vector Number) Natural (Listof Natural) @@ -121,6 +119,31 @@ #; (: tp-force (-> tpromise (U flat Number))) +#; +(define force/eval + (lambda (delayed-expr/env) + (let-values (((instructions env) + (compile-delayed-expr delayed-expr/env)))) + (run-instructions instructions env))) + +;; run-instructions and compile-delayed-expr +;; Phase 1 : instructions are scheme and run-instructions is basically eval +;; -- determine the separation between the compiler and run-time-system. +;; -- 1 week +;; Phase 2 : instructions are C and run-instructions is FFI + C. +;; -- write the runtime-system in C +;; -- 3 weeks +;; Phase 3 : instructions are OpenCL and run-instructions is FFI+ C +;; -- write the runtime-system in openCL. +;; -- 3 weeks +;; -- Distribution across machines --> Ph. D. Thesis. +;; Phase 4 : instructions are SPIR-V and run-instructions is a SPIR-V driver. +;; -- write the runtime-system in openCL with a SPIR-V target (?) +;; -- 6 weeks +;; Phase 5 : instructions are custom, and runtime system is on FPGA. +;; -- build VHDL blocks for custom instructions. +;; -- Ph. D. Thesis. + (define tp-force (lambda (tp (print? #f)) (when print? @@ -147,14 +170,14 @@ (define tcomp-force (λ (tc) (match tc - [(tcomp-list->tpromise-list lst) + [(tcomp-list->tensor lst) (flat:list->tensor (map (λ (l) (tp-force l #f)) lst))] - [(tcomp-build-tpromise s f) + [(tcomp-build-tensor s f) (flat:build-tensor s f)] - [(tcomp-tp-tref tp i) + [(tcomp-tref tp i) (flat:tref (tp-force tp) i)] - [(tcomp-tp-trefs tp b) + [(tcomp-trefs tp b) (flat:trefs (tp-force tp) b)] [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let* ([b (if (zero? i) out0 out1)] @@ -178,7 +201,7 @@ (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) (tp-scalarize (flat-ext1-∇ flat-f m shape-fn t z))))] [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (f (tp-force tp-t) (tp-force tp-t))] + (f (tp-force tp-t) (tp-force tp-u))] [(tcomp-ext2-ρ-prealloc tp-t tp-u f m n shape-fn) (tp-scalarize (flat-ext2-ρ f m n shape-fn @@ -230,7 +253,7 @@ (lambda (tp i) (cond [(bounded-idx*? tp (list i)) - (tpromise (tcomp-tp-tref tp i) + (tpromise (tcomp-tref tp i) (cdr (tpromise-shape tp)))] [else (error 'exn:tp-tref (string-append @@ -255,14 +278,14 @@ [(null? lst) (error 'list->ltensor "No elements found")] [else - (tpromise (tcomp-list->tpromise-list lst) + (tpromise (tcomp-list->tensor lst) `(,(length lst) . ,(tp-shape (car lst))))]))) (define build-tpromise (λ (s f) - (tpromise (tcomp-build-tpromise s f) s))) + (tpromise (tcomp-build-tensor s f) s))) (define tp-trefs (λ (tp b) @@ -274,7 +297,7 @@ (error 'tp-trefs "An index was out of bounds")] [else - (tpromise (tcomp-tp-trefs tp b) + (tpromise (tcomp-trefs tp b) `(,(length b) . ,(cdr (tpromise-shape tp))))]))) @@ -464,7 +487,6 @@ (set-box! out1 (tp-scalarize (flat s1 g1 0))))))))) -;; TODO: Create a lazy-apply-2 that does this more generally (define tp-d-ext2^ (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) (let* ((out0 (box 'uncalculated)) From 770654a0122762c2bce3651125d3e7e1c59f8f12 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 9 Aug 2023 16:17:50 -0400 Subject: [PATCH 61/83] [add-lazy]Separate run and compile phases --- flat-tensors/ext-impl.rkt | 3 +- flat-tensors/tensors/D-extend.rkt | 2 +- lazy/autodiff/A-autodiff.rkt | 2 +- lazy/autodiff/D-test-helpers.rkt | 8 +- lazy/autodiff/E-print.rkt | 2 +- lazy/tensors.rkt | 2 +- lazy/tensors/0-lazy.rkt | 199 ++++++++++++++---------------- lazy/tensors/A-equality.rkt | 2 +- lazy/tensors/test/test-0-lazy.rkt | 42 +++---- malted/test/test-O-init.rkt | 8 +- 10 files changed, 127 insertions(+), 143 deletions(-) diff --git a/flat-tensors/ext-impl.rkt b/flat-tensors/ext-impl.rkt index 2530b4f..8a47635 100644 --- a/flat-tensors/ext-impl.rkt +++ b/flat-tensors/ext-impl.rkt @@ -14,7 +14,8 @@ functional->preallocated-1-∇ functional->preallocated-2-ρ functional->preallocated-2-∇ - idxs)) + idxs + scalarize)) (require (only-in "autodiff/E-print.rkt" make-printable-flat fake-tensor)) diff --git a/flat-tensors/tensors/D-extend.rkt b/flat-tensors/tensors/D-extend.rkt index edca5fa..6670747 100644 --- a/flat-tensors/tensors/D-extend.rkt +++ b/flat-tensors/tensors/D-extend.rkt @@ -389,4 +389,4 @@ functional->preallocated-1-ρ functional->preallocated-1-∇ functional->preallocated-2-ρ functional->preallocated-2-∇ merge-shapes min-shape ext2-shapes idxs - flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ) + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index b4210f9..df57f32 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -73,7 +73,7 @@ (λ (y wrt) (let ((σ (∇σ y (hasheq)))) (map* (λ (d) - (tp-force (hash-ref σ d 0.0))) + (force/eval (hash-ref σ d 0.0))) wrt)))) (define ∇σ diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index f504e0e..6d11570 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -1,19 +1,19 @@ #lang racket (require "../tensors.rkt") -(require (only-in "../tensors/0-lazy.rkt" tp-force)) +(require (only-in "../tensors/0-lazy.rkt" force/eval)) (require "A-autodiff.ss") (require rackunit) (define forced-ρ (λ (d) - (tp-force (ρ d)))) + (force/eval (ρ d)))) (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) - (let* ((y (tp-force (apply fn args))) - (g (tp-force (apply (∇¹ fn) args))) + (let* ((y (force/eval (apply fn args))) + (g (force/eval (apply (∇¹ fn) args))) (ans-ρ (ρ ans))) (cond ((and (equal-wt? ans-ρ (ρ y)) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt index 66494d4..2dd32ae 100644 --- a/lazy/autodiff/E-print.rkt +++ b/lazy/autodiff/E-print.rkt @@ -11,7 +11,7 @@ (cond ((dual? y) (make-printable (ρ y))) ((tpromise? y) - (make-printable (tp-force y) max-length)) + (make-printable (force/eval y) max-length)) ((flat? y) (make-printable-flat y max-length)) ((list? y) (map (λ (le) (make-printable le max-length)) y)) diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt index e86ab07..dbbf905 100644 --- a/lazy/tensors.rkt +++ b/lazy/tensors.rkt @@ -11,7 +11,7 @@ (provide ext1-ρ ext2-ρ ext1-∇ ext2-∇) -(provide tp-force scalarize) +(provide force/eval scalarize) ;; These will get overriden by duals (provide tensor?) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index d3e78b1..e303e42 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -117,123 +117,103 @@ ((scalar? v) (flat '() (vec v) 0)) (else v)))) -#; -(: tp-force (-> tpromise (U flat Number))) -#; +(define-namespace-anchor a) + +(define run-instrs + (lambda (instrs) + (let ([env (namespace-anchor->namespace a)]) + (eval instrs env)))) + (define force/eval - (lambda (delayed-expr/env) - (let-values (((instructions env) - (compile-delayed-expr delayed-expr/env)))) - (run-instructions instructions env))) - -;; run-instructions and compile-delayed-expr -;; Phase 1 : instructions are scheme and run-instructions is basically eval -;; -- determine the separation between the compiler and run-time-system. -;; -- 1 week -;; Phase 2 : instructions are C and run-instructions is FFI + C. -;; -- write the runtime-system in C -;; -- 3 weeks -;; Phase 3 : instructions are OpenCL and run-instructions is FFI+ C -;; -- write the runtime-system in openCL. -;; -- 3 weeks -;; -- Distribution across machines --> Ph. D. Thesis. -;; Phase 4 : instructions are SPIR-V and run-instructions is a SPIR-V driver. -;; -- write the runtime-system in openCL with a SPIR-V target (?) -;; -- 6 weeks -;; Phase 5 : instructions are custom, and runtime system is on FPGA. -;; -- build VHDL blocks for custom instructions. -;; -- Ph. D. Thesis. - -(define tp-force (lambda (tp (print? #f)) (when print? (printf "~n####PP tensor: ") (pretty-print tp)) - (let ([res - (match tp - [(tpromise t-tcomp _) - #:when (tcomp? t-tcomp) - (tcomp-force t-tcomp)] - [(tpromise t _) - #:when (or (flat? t) (scalar? t)) t] - - ;; NOTE: This case runs when we use tp-scalarize to turn - ;; the tensor to a scalar - [_ #f])]) - (cond - [res (set-tpromise-tensor! tp res) - res] - [else tp])))) + (match tp + [(tpromise t-tcomp _) + #:when (tcomp? t-tcomp) + (let* ((instrs (compile-expr t-tcomp '())) + (res (run-instrs instrs))) + (set-tpromise-tensor! tp res) + res)] + [(tpromise t _) + #:when (or (flat? t) (scalar? t)) t] + ;; NOTE: This case runs when we use tp-scalarize to turn + ;; the tensor to a scalar + (else tp)))) #; (: tcomp-force (-> tcomp (U flat Number))) -(define tcomp-force - (λ (tc) +(define compile-expr + (λ (tc t-env) (match tc [(tcomp-list->tensor lst) - (flat:list->tensor - (map (λ (l) (tp-force l #f)) lst))] + `(flat:list->tensor + (map (λ (l) (force/eval l #f)) ',lst))] [(tcomp-build-tensor s f) - (flat:build-tensor s f)] + `(flat:build-tensor ',s ,f)] [(tcomp-tref tp i) - (flat:tref (tp-force tp) i)] + `(flat:tref (force/eval ,tp) ,i)] [(tcomp-trefs tp b) - (flat:trefs (tp-force tp) b)] + `(flat:trefs (force/eval ,tp) ',b)] [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let* ([b (if (zero? i) out0 out1)] - [v (unbox b)]) - (cond - ((eqv? v 'uncalculated) - (ext2-∇-forcer fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1) - (unbox b)) - (else v)))] + `(let* ([b (if (zero? ,i) ,out0 ,out1)] + [v (ext2-∇-result-res b)]) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn + (force/eval ,tp-t0) + (force/eval ,tp-t1) + (force/eval ,tp-z) + ,out0 ,out1) + (ext2-∇-result-res b)) + (else v)))] [(tcomp-ext1-∇-prealloc tp zp f m shape-fn) - (tp-scalarize - (flat-ext1-∇ f m shape-fn - (ensure-flat (tp-force tp)) - (ensure-flat (tp-force zp))))] + `(scalarize + (flat-ext1-∇ ,f ,m ,shape-fn + (ensure-flat (force/eval ,tp)) + (ensure-flat (force/eval ,zp))))] [(tcomp-ext1-∇ tp zp f m shape-fn) - (let ([t (ensure-flat (tp-force tp #f))] - [z (ensure-flat (tp-force zp #f))]) - (let* ((in-shape (flat-shape t)) - (base-shape (min-shape m in-shape)) - (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (tp-scalarize (flat-ext1-∇ flat-f m shape-fn t z))))] + `(let ([t (ensure-flat (force/eval ,tp #f))] + [z (ensure-flat (force/eval ,zp #f))]) + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape ,m in-shape)) + (out-shape (,shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ ,f base-shape out-shape))) + (scalarize (flat-ext1-∇ flat-f ,m ,shape-fn t z))))] [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (f (tp-force tp-t) (tp-force tp-u))] + `(,f (force/eval ,tp-t) (force/eval ,tp-u))] [(tcomp-ext2-ρ-prealloc tp-t tp-u f m n shape-fn) - (tp-scalarize - (flat-ext2-ρ f m n shape-fn - (ensure-flat (tp-force tp-t)) - (ensure-flat (tp-force tp-u))))] + `(scalarize + (flat-ext2-ρ ,f ,m ,n ,shape-fn + (ensure-flat (force/eval ,tp-t)) + (ensure-flat (force/eval ,tp-u))))] [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (let ([t (ensure-flat (tp-force tp-t #f))] - [u (ensure-flat (tp-force tp-u #f))]) - (let* ((t-shape (min-shape m (flat-shape t))) - (u-shape (min-shape n (flat-shape u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ f t-shape u-shape out-shape))) - (tp-scalarize - (flat-ext2-ρ flat-f m n shape-fn t u))))] + `(let ([t (ensure-flat (force/eval ,tp-t #f))] + [u (ensure-flat (force/eval ,tp-u #f))]) + (let* ((t-shape (min-shape ,m (flat-shape t))) + (u-shape (min-shape ,n (flat-shape u))) + (out-shape (,shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ ,f t-shape u-shape out-shape))) + (scalarize + (flat-ext2-ρ flat-f ,m ,n ,shape-fn t u))))] [(tcomp-ext1-ρ-scalar f tp) - (f (tp-force tp))] + `(,f (force/eval ,tp))] [(tcomp-ext1-ρ-prealloc f m shape-fn tp) - (tp-scalarize (flat-ext1-ρ f m shape-fn (ensure-flat (tp-force tp))))] + `(scalarize (flat-ext1-ρ ,f ,m ,shape-fn (ensure-flat (force/eval ,tp))))] [(tcomp-ext1-ρ f m shape-fn tp) - (let ([t (ensure-flat (tp-force tp #f))]) - (let* ((in-shape (flat-shape t)) - (base-shape (min-shape m in-shape)) - (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) - (tp-scalarize - (flat-ext1-ρ flat-f m shape-fn t))))] + `(let ([t (ensure-flat (force/eval ,tp #f))]) + (let* ((in-shape (flat-shape t)) + (base-shape (min-shape ,m in-shape)) + (out-shape (,shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ ,f base-shape out-shape))) + (scalarize + (flat-ext1-ρ flat-f ,m ,shape-fn t))))] [(tcomp-reshape s tp) - (let ([t (tp-force tp #f)]) - (flat s (flat-store t) (flat-offset t)))] + `(let ([t (force/eval ,tp #f)]) + (flat ',s (flat-store t) (flat-offset t)))] [(tcomp-tensor args) - - (merge-flats (map tp-force args))]))) + `(merge-flats (map force/eval ',args))]))) (define bounded-idx*^ (λ (shape idx*) @@ -361,11 +341,13 @@ (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) +;; We may have to replace tp-scalarize with scalarize from flat-tensors, because +;; the force/eval used in its definition is undesirable. (define tp-scalarize (λ (tp) (cond [(and (tpromise? tp) (null? (tpromise-shape tp))) - (tp-scalarize (tp-force tp))] + (tp-scalarize (force/eval tp))] [(and (flat? tp) (null? (flat-shape tp))) (vref (flat-store tp) 0)] [else tp]))) @@ -433,22 +415,22 @@ (ensure-tpromise tp-z)))]))))) (define ext2-∇-forcer - (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1) - (let* ((s0 (tp-shape tp-t0)) + (λ (fᵈ r0 r1 shape-fn t0 t1 z out0 out1) + (let* ((f0 (ensure-flat t0)) + (f1 (ensure-flat t1)) + (fz (ensure-flat z)) + + (s0 (flat-shape f0)) (sf0 (min-shape r0 s0)) (stride0 (flat:size-of sf0)) - (s1 (tp-shape tp-t1)) + (s1 (flat-shape t1)) (sf1 (min-shape r1 s1)) (stride1 (flat:size-of sf1)) (sf-z (shape-fn sf0 sf1)) (stride-z (flat:size-of sf-z)) - (f0 (ensure-flat (tp-force tp-t0))) - (f1 (ensure-flat (tp-force tp-t1))) - (fz (ensure-flat (tp-force tp-z))) - (v0 (flat-store f0)) (v1 (flat-store f1)) (vz (flat-store fz)) @@ -482,15 +464,16 @@ vz (+ offz iz) stride-z))) - (set-box! out0 + (set-ext2-∇-result-res! out0 (tp-scalarize (flat s0 g0 0))) - (set-box! out1 + (set-ext2-∇-result-res! out1 (tp-scalarize (flat s1 g1 0))))))))) +(struct ext2-∇-result (res) #:mutable #:transparent) (define tp-d-ext2^ (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) - (let* ((out0 (box 'uncalculated)) - (out1 (box 'uncalculated))) + (let* ((out0 (ext2-∇-result 'uncalculated)) + (out1 (ext2-∇-result 'uncalculated))) (values (tpromise (tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 0) (tp-shape tp-t0)) @@ -551,12 +534,12 @@ (define force*1 (λ (t f) - (f (tp-force t)))) + (f (force/eval t)))) (define force*2 (λ (ts f) (let-values (((t1 t2) (ts))) - (f (tp-force t1) (tp-force t2))))) + (f (force/eval t1) (force/eval t2))))) (include "test/test-0-lazy.rkt") @@ -567,7 +550,7 @@ (flat:ref ref) (flat:refr refr))) (provide tensor - tp-force + force/eval tpromise? (rename-out (tp-scalarize scalarize) diff --git a/lazy/tensors/A-equality.rkt b/lazy/tensors/A-equality.rkt index 93851b1..a03fa26 100644 --- a/lazy/tensors/A-equality.rkt +++ b/lazy/tensors/A-equality.rkt @@ -5,7 +5,7 @@ (define tp-tensor-equal? (λ (tp-actual tp-expected) - (flat:tensor-equal? (tp-force tp-actual) (tp-force tp-expected)))) + (flat:tensor-equal? (force/eval tp-actual) (force/eval tp-expected)))) (require rackunit) (define-binary-check (tp-check-tensor-equal? tp-tensor-equal? actual expected)) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index 09da45a..b5777f1 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -3,16 +3,16 @@ (define test-lt (tensor 1 2 3)) (check-true (flat? (tpromise-tensor test-lt))) - (check-equal? (flat-store (tp-force test-lt)) (vector 1 2 3)) + (check-equal? (flat-store (force/eval test-lt)) (vector 1 2 3)) (check-true (flat? (tpromise-tensor test-lt))) (check-exn exn:fail? (λ () (tensor test-lt 4))) (check-exn exn:fail? (λ () (tensor 4 test-lt))) - (check-equal? (tp-force (tp-tref test-lt 2)) 3) + (check-equal? (force/eval (tp-tref test-lt 2)) 3) (check-exn exn:fail? (λ () (tp-tref test-lt 5))) (define test-nested-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) - (check-equal? (tp-force (tp-tref (tp-tref test-nested-lt 0) 2)) 3) + (check-equal? (force/eval (tp-tref (tp-tref test-nested-lt 0) 2)) 3) (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0))) (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2))) (check-exn exn:fail? (λ () (tensor test-nested-lt test-nested-lt test-lt))) @@ -21,7 +21,7 @@ (check-equal? (tp-tlen test-nested-lt) 2) (define test-lt-from-list (list->tpromise '(5 6 7 8))) - (check-equal? (flat-store (tp-force test-lt-from-list)) (vector 5 6 7 8)) + (check-equal? (flat-store (force/eval test-lt-from-list)) (vector 5 6 7 8)) (define test-nested-lt-from-list (list->tpromise `(,test-lt ,test-lt ,test-lt))) (check-equal? (tpromise-shape test-nested-lt-from-list) '(3 3)) @@ -44,8 +44,8 @@ (define test-refs '(0 2)) (define test-tp-trefs (tp-trefs test-built-tensor test-refs)) (check-true (tcomp? (tpromise-tensor test-tp-trefs))) - (check-equal? (tpromise-shape test-tp-trefs) (flat-shape (tp-force test-tp-trefs))) - (check-equal? (flat-store (tp-force test-tp-trefs)) (vector 0 1 2 6 7 8)) + (check-equal? (tpromise-shape test-tp-trefs) (flat-shape (force/eval test-tp-trefs))) + (check-equal? (flat-store (force/eval test-tp-trefs)) (vector 0 1 2 6 7 8)) (check-exn exn:fail? (λ () (tp-trefs test-nested-lt '(0 4)))) (define sum-f @@ -55,11 +55,11 @@ (+ sum (vref in-v i)))))) (define sum (tp-ext1-ρ sum-f 1)) - (check-equal? (flat-store (tp-force (sum test-nested-lt))) (vec 6.0 15.0)) + (check-equal? (flat-store (force/eval (sum test-nested-lt))) (vec 6.0 15.0)) (define id-f (lambda (v) v)) (define id-ρ (tp-ext1-ρ id-f 1 (λ (s) s))) - (check-equal? (flat-store (tp-force (id-ρ test-nested-lt))) (vec 1 2 3 4 5 6)) + (check-equal? (flat-store (force/eval (id-ρ test-nested-lt))) (vec 1 2 3 4 5 6)) (define t0 (build-tpromise '(2 3 4) @@ -69,7 +69,7 @@ (define *-ρ (tp-ext2-ρ * 0 0)) (define t0sqr (*-ρ t0 t0)) - (flat:check-tensor-equal? (tp-force t0sqr) + (flat:check-tensor-equal? (force/eval t0sqr) (flat:reshape '(2 3 4) (flat:tensor @@ -104,7 +104,7 @@ (*-2-1 t1 t2)) (check-equal? (tpromise-shape r-1-2) '(5 6)) - (flat:check-tensor-equal? (tp-force r-1-2) + (flat:check-tensor-equal? (force/eval r-1-2) (flat:reshape '(5 6) (flat:tensor @@ -130,7 +130,7 @@ (*-2-1 t3 t4)) (check-equal? (tpromise-shape r-3-4) '(3 5 6)) - (flat:check-tensor-equal? (tp-force r-3-4) + (flat:check-tensor-equal? (force/eval r-3-4) (flat:reshape '(3 5 6) (flat:tensor @@ -168,11 +168,11 @@ (λ (t) (build-tpromise (tpromise-shape t) (λ (_) 1.0)))) - (flat:check-tensor-equal? (tp-force (d-sqr r1-td (one-like r1-td))) + (flat:check-tensor-equal? (force/eval (d-sqr r1-td (one-like r1-td))) (flat:tensor 6.0 8.0 10.0)) (let ((gsqr (d-sqr r2-td (one-like r2-td)))) - (flat:check-tensor-equal? (tp-force gsqr) + (flat:check-tensor-equal? (force/eval gsqr) (flat:reshape '(2 3) (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) @@ -180,15 +180,15 @@ (define d+ (tp-ext2-∇ +ᵈ 0 0 scalar-shape)) (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) - (flat:check-tensor-equal? (tp-force da) + (flat:check-tensor-equal? (force/eval da) (flat:tensor 1.0 1.0 1.0)) - (flat:check-tensor-equal? (tp-force db) + (flat:check-tensor-equal? (force/eval db) (flat:tensor 1.0 1.0 1.0))) (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) - (flat:check-tensor-equal? (tp-force da) + (flat:check-tensor-equal? (force/eval da) (flat:tensor 2.0 2.0 2.0)) - (flat:check-tensor-equal? (tp-force db) + (flat:check-tensor-equal? (force/eval db) (flat:reshape '(2 3) (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) @@ -200,8 +200,8 @@ (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0)))) - (flat:check-tensor-equal? (tp-force gt) (tp-force (tensor 1.0 2.0 3.0))) - (flat:check-tensor-equal? (tp-force gu) (tp-force (tensor 2.0 3.0 4.0)))) + (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 2.0 3.0))) + (flat:check-tensor-equal? (force/eval gu) (force/eval (tensor 2.0 3.0 4.0)))) (define sum-1-∇ (λ (g t it st vz iz sz) @@ -212,10 +212,10 @@ (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) 1.0))) - (flat:check-tensor-equal? (tp-force gt) (tp-force (tensor 1.0 1.0 1.0)))) + (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 1.0 1.0)))) (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0)))) - (flat:check-tensor-equal? (tp-force gt) (tp-force (tensor (tensor 2.0 2.0 2.0) + (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor (tensor 2.0 2.0 2.0) (tensor 1.0 1.0 1.0)))))) diff --git a/malted/test/test-O-init.rkt b/malted/test/test-O-init.rkt index 342300e..1422455 100644 --- a/malted/test/test-O-init.rkt +++ b/malted/test/test-O-init.rkt @@ -3,16 +3,16 @@ ;; TODO: Make this better. We musn't break abstraction boundaries (require "../lazy/tensors/0-lazy.rkt") (define v (init-shape (list 1000 4))) - (define mean-v (tp-force (abs (/ (sum (sum v)) 4000)))) - (define variance-v (tp-force (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v)))) + (define mean-v (force/eval (abs (/ (sum (sum v)) 4000)))) + (define variance-v (force/eval (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v)))) (check-true (< mean-v 0.05)) (check-true (and (>= variance-v 0.4) (<= variance-v 0.6))) ;; Here variance will be 2/8 = 0.25 (define r (init-shape (list 1000 4 2))) - (define mean-r (tp-force (abs (/ (sum (sum (sum r))) 8000)))) - (define variance-r (tp-force (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r)))) + (define mean-r (force/eval (abs (/ (sum (sum (sum r))) 8000)))) + (define variance-r (force/eval (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r)))) (check-true (< mean-r 0.05)) (check-true (and (>= variance-r 0.22) From 9723892312082766b81a914a5fc827d8723bc4f8 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 9 Sep 2023 06:30:24 -0400 Subject: [PATCH 62/83] [add-lazy]compile racket WIP --- lazy/autodiff/B-prims.rkt | 3 - lazy/autodiff/E-print.rkt | 2 +- lazy/tensors/0-lazy.rkt | 520 ++++++++++++++++++++++++++------------ 3 files changed, 365 insertions(+), 160 deletions(-) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index b7fc2fc..ff532af 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -29,9 +29,6 @@ (let ((ra (ρ da))) (dual (ρ-fn ra) (λ (d z σ) - ;; TODO: need force*-1 here while calling ∇-fn - #;(let ((ga (∇-fn ra z))) - ((κ da) da #;ga (tp-force ga) σ)) (force*1 (∇-fn ra z) (λ (ga) ((κ da) da ga σ)))))))) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt index 2dd32ae..9ae1afa 100644 --- a/lazy/autodiff/E-print.rkt +++ b/lazy/autodiff/E-print.rkt @@ -2,7 +2,7 @@ (require "A-autodiff.rkt") (require "../tensors/0-lazy.rkt") -(require "../../flat-tensors/ext-impl.rkt") +(require (except-in "../../flat-tensors/ext-impl.rkt" scalarize)) (define max-tensor-print-length (make-parameter 5)) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index e303e42..57111dc 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -70,44 +70,58 @@ (-> Number * tpromise))) (define tensor (λ args - (unless (ensure-shape args) - (error 'tensor - "Mismatched shapes: ~a~%" - args)) - - (let ((inner-flat (tensor-inner-flat args))) - (cond - ((flat? inner-flat) - (tpromise inner-flat (flat-shape inner-flat))) - (else - (let* ((inner-shape (tpromise-shape (car args))) - (outer (length args)) - (new-shape (cons outer inner-shape))) - (tpromise inner-flat new-shape))))))) + (list->tpromise args))) #; -(: tensor-inner-flat (-> (U (Listof tpromise) (Listof Number)) - (U flat tcomp))) +(: tensor-inner-flat (-> (Listof (U tpromise Number)) + (U flat tcomp-list->tensor))) (define tensor-inner-flat - (λ (args) + (λ (lst) (cond - [(number? (car args)) (apply flat:tensor args)] - [else (tcomp-tensor args)]))) + [(andmap number? lst) (apply flat:tensor lst)] + [else (tcomp-list->tensor lst)]))) #; -(: ensure-shape (-> (U (Listof tpromise) (Listof Number)) Boolean)) +(: ensure-shape (-> (U (Listof tpromise) (Listof Number)) Void)) (define ensure-shape (λ (args) - (and (not (null? args)) - (cond - ((number? (car args)) - (andmap number? (cdr args))) - ((tpromise? (car args)) - (let ((s (tp-shape (car args)))) - (andmap (λ (t) - (and (tpromise? t) - (equal? (tp-shape t) s))) - (cdr args)))) - (else #f))))) + (when (null? args) + (error 'tensor "Tensors cannot be empty")) + (let ((checked-shape + (λ (x) (if (tpromise? x) + (tpromise-shape x) + '()))) + (scalar-like? + (λ (x) + (or (number? x) + (and (tpromise? x) + (null? (tpromise-shape x))))))) + (unless (and (not (null? args)) + (cond + ((scalar-like? (car args)) + (andmap scalar-like? (cdr args))) + ((tpromise? (car args)) + (let ((s (checked-shape (car args)))) + (andmap (λ (t) + (and (tpromise? t) + (equal? (checked-shape t) s))) + (cdr args)))) + (else #f))) + (error 'tensor + "Cannot construct a tensor out of these elements: ~a~%" + args))))) + +(define list->tpromise + (λ (lst) + (ensure-shape lst) + (let ((inner-flat (tensor-inner-flat lst))) + (cond + ((flat? inner-flat) + (tpromise inner-flat (flat-shape inner-flat))) + (else + (let* ((inner-shape (tp-shape (car lst))) + (outer (length lst)) + (new-shape (cons outer inner-shape))) + (tpromise inner-flat new-shape))))))) #; (: ensure-flat (-> (U flat Number) flat)) @@ -130,90 +144,284 @@ (printf "~n####PP tensor: ") (pretty-print tp)) (match tp - [(tpromise t-tcomp _) - #:when (tcomp? t-tcomp) - (let* ((instrs (compile-expr t-tcomp '())) - (res (run-instrs instrs))) - (set-tpromise-tensor! tp res) - res)] [(tpromise t _) - #:when (or (flat? t) (scalar? t)) t] + #:when (or (flat? t) (scalar? t) (tcomp? t)) + + (let-values (((instrs locals env) + (run-compiler (compile-expr t (count-references t (hasheq))) + '() + '()))) + (let ((res (run-instrs instrs locals env))) + (set-tpromise-tensor! tp res) + res))] ;; NOTE: This case runs when we use tp-scalarize to turn ;; the tensor to a scalar - (else tp)))) + (_ tp)))) -#; -(: tcomp-force (-> tcomp (U flat Number))) +(struct counter-data (binding-name + ref-count + (compiled? #:mutable #:auto)) + #:transparent) + +;; Count references so that the tcomp AST nodes that refer to the same memory +;; location i.e. common AST nodes get extracted by let-binding them in the +;; compiled output racket code. +(define count-tcomp-references + (λ (tc counter) + (match-let (((counter-data tc-binding-name tc-ref-count _) + (hash-ref counter tc + (λ () + (counter-data (gensym 'local) 0))))) + (let ((counter^ (hash-set counter tc + (counter-data tc-binding-name + (add1 tc-ref-count))))) + (match tc + [(tcomp-list->tensor lst) + (for/fold + ((counter^^ counter^)) + ((l lst)) + (count-references l counter^^))] + [(tcomp-build-tensor s f) counter^] + [(tcomp-tref tp i) + (count-references tp counter^)] + [(tcomp-trefs tp b) + (count-references tp counter^)] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (count-references + tp-z + (count-references + tp-t1 + (count-references tp-t0 counter^)))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (count-references + zp + (count-references tp counter^))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (count-references + tp-u + (count-references tp-t counter^))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (count-references + tp-u + (count-references tp-t counter^))] + [(tcomp-ext1-ρ-scalar f tp) + (count-references tp counter^)] + [(tcomp-ext1-ρ f m shape-fn tp) + (count-references tp counter^)] + [(tcomp-reshape s tp) + (count-references tp counter^)]))))) + +(define count-references + (λ (t counter) + (match t + ((tpromise tc _) + (count-references tc counter)) + ((tcomp) (count-tcomp-references t counter)) + (_ counter)))) + +;; TODO: Add tcomp nodes for let and var so that common refs to tcomp can become +;; tcomp-vars and be bound in a tcomp-let. +;; +;; TODO: Add another intermediate pass that generates tcomp-let and tcomp-var +;; nodes in such a way that it converts the abstract syntax graph to a abstract +;; syntax tree. +#| +(tcomp ..) => + (tcomp-let ((var (tcomp ...))) (tcomp ... var ....)) + +(tpromise (tcomp ... (tcomp-var name) ...)) +|# + +(define extend-env + (λ (k v env) + `((,k . ,v) . ,env))) +(define extend-locals extend-env) + +(define exists-in-env? + (λ (ft env) + (match env + ('() #f) + (`((,k . ,v) . ,_) #:when (eq? ft v) k) + (`(,_ . ,rest-env) (exists-in-env? ft rest-env))))) + +;; (: run-compiler (All (Instrs Env) (-> Env (Values Instrs Env)))) +(struct Compiler (run-compiler) #:transparent) + +(define run-compiler + (λ (c locals env) + ((Compiler-run-compiler c) locals env))) + +(define inj-compiler-val + (λ (v) + (Compiler (λ (locals env) (values v locals env))))) + +;; TODO: In the future scalars must be in the environment rather than the +;; compiled instrs. This will make the instruction signature fully independent +;; of the data in the environmnt. +(define inj-compiler-flat + (λ (ft) + (Compiler + (λ (locals env) + (cond + ((exists-in-env? ft env) => (lambda (var) (values var locals env))) + ((and (flat? ft) (null? (flat-shape ft))) + (values ft locals env)) + (else + (let ((new-var (gensym 'ft))) + (values new-var locals (extend-env new-var ft env))))))))) + +(define inj-compiler-instrs + (λ (instrs cd) + (match-let (((counter-data binding-var ref-count _) cd)) + (Compiler (λ (locals env) + (pretty-print instrs) + (cond + ((<= ref-count 1) + (println "ref <= 1") + (values instrs locals env)) + ((assv binding-var locals) + (println "assv") + (values binding-var locals env)) + (else + (println "else") + (values binding-var + (extend-locals binding-var instrs locals) + env)))))))) + +(define ->c + (λ (c f) + (Compiler (λ (locals env) + (let-values (((instrs locals^ env^) (run-compiler c locals env))) + (run-compiler (f instrs) locals^ env^)))))) + + +(define compile-tcomp + (λ (tc counter) + (let ((tc-counter-data + (hash-ref counter tc + (λ () + (counter-data (gensym 'illegal) 0))))) + (match tc + [(tcomp-list->tensor lst) + (let ((instrs-list-compiler + (for/foldr ((list-compiler (inj-compiler-val '()))) + ((arg lst)) + (->c + (compile-expr arg counter) + (λ (instrs) + (->c + list-compiler + (λ (instrs-list) + (inj-compiler-val (cons instrs instrs-list))))))))) + (->c + instrs-list-compiler + (λ (instrs-list) + (inj-compiler-instrs `(flat:list->tensor (list ,@instrs-list)) + tc-counter-data))))] + [(tcomp-build-tensor s f) + (inj-compiler-flat (flat:build-tensor s f))] + [(tcomp-tref tp i) + (->c + (compile-expr tp counter) + (λ (instrs) + (inj-compiler-instrs `(flat:tref ,instrs ,i) tc-counter-data)))] + [(tcomp-trefs tp b) + (->c + (compile-expr tp counter) + (λ (instrs) + (inj-compiler-instrs `(flat:trefs ,instrs ',b) tc-counter-data)))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (->c + (compile-expr tp-t0 counter) + (λ (t0-instrs) + (->c + (compile-expr tp-t1 counter) + (λ (t1-instrs) + (->c + (compile-expr tp-z counter) + (λ (z-instrs) + (inj-compiler-instrs + `(let* ([b (if (zero? ,i) ,out0 ,out1)] + [v (ext2-∇-result-res b)]) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out0 ,out1) + (ext2-∇-result-res b)) + (else v))) + tc-counter-data)))))))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (->c + (compile-expr tp counter) + (λ (t-instrs) + (->c + (compile-expr zp counter) + (λ (z-instrs) + (inj-compiler-instrs + `(scalarize + (flat-ext1-∇ ,f ,m ,shape-fn + (ensure-flat ,t-instrs) + (ensure-flat ,z-instrs))) + tc-counter-data)))))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (->c + (compile-expr tp-t counter) + (λ (t-instrs) + (->c + (compile-expr tp-u counter) + (λ (u-instrs) + (inj-compiler-instrs + `(,f ,t-instrs ,u-instrs) + tc-counter-data)))))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (->c + (compile-expr tp-t counter) + (λ (t-instrs) + (->c + (compile-expr tp-u counter) + (λ (u-instrs) + (inj-compiler-instrs + `(scalarize + (flat-ext2-ρ ,f ,m ,n ,shape-fn + (ensure-flat ,t-instrs) + (ensure-flat ,u-instrs))) + tc-counter-data)))))] + [(tcomp-ext1-ρ-scalar f tp) + (->c + (compile-expr tp counter) + (λ (instrs) + (inj-compiler-instrs `(,f ,instrs) + tc-counter-data)))] + [(tcomp-ext1-ρ f m shape-fn tp) + (->c + (compile-expr tp counter) + (λ (instrs) + (inj-compiler-instrs `(scalarize + (flat-ext1-ρ ,f ,m ,shape-fn + (ensure-flat ,instrs))) + tc-counter-data)))] + [(tcomp-reshape s tp) + (->c + (compile-expr tp counter) + (λ (instrs) + (inj-compiler-instrs `(flat ',s + (flat-store ,instrs) + (flat-offset ,instrs)) + tc-counter-data)))])))) + +(define print-compiler? (make-parameter #f)) (define compile-expr - (λ (tc t-env) + (λ (tc counter) + (when (print-compiler?) + (pretty-print tc)) (match tc - [(tcomp-list->tensor lst) - `(flat:list->tensor - (map (λ (l) (force/eval l #f)) ',lst))] - [(tcomp-build-tensor s f) - `(flat:build-tensor ',s ,f)] - [(tcomp-tref tp i) - `(flat:tref (force/eval ,tp) ,i)] - [(tcomp-trefs tp b) - `(flat:trefs (force/eval ,tp) ',b)] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - `(let* ([b (if (zero? ,i) ,out0 ,out1)] - [v (ext2-∇-result-res b)]) - (cond - ((eqv? v 'uncalculated) - (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn - (force/eval ,tp-t0) - (force/eval ,tp-t1) - (force/eval ,tp-z) - ,out0 ,out1) - (ext2-∇-result-res b)) - (else v)))] - [(tcomp-ext1-∇-prealloc tp zp f m shape-fn) - `(scalarize - (flat-ext1-∇ ,f ,m ,shape-fn - (ensure-flat (force/eval ,tp)) - (ensure-flat (force/eval ,zp))))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - `(let ([t (ensure-flat (force/eval ,tp #f))] - [z (ensure-flat (force/eval ,zp #f))]) - (let* ((in-shape (flat-shape t)) - (base-shape (min-shape ,m in-shape)) - (out-shape (,shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ ,f base-shape out-shape))) - (scalarize (flat-ext1-∇ flat-f ,m ,shape-fn t z))))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - `(,f (force/eval ,tp-t) (force/eval ,tp-u))] - [(tcomp-ext2-ρ-prealloc tp-t tp-u f m n shape-fn) - `(scalarize - (flat-ext2-ρ ,f ,m ,n ,shape-fn - (ensure-flat (force/eval ,tp-t)) - (ensure-flat (force/eval ,tp-u))))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - `(let ([t (ensure-flat (force/eval ,tp-t #f))] - [u (ensure-flat (force/eval ,tp-u #f))]) - (let* ((t-shape (min-shape ,m (flat-shape t))) - (u-shape (min-shape ,n (flat-shape u))) - (out-shape (,shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ ,f t-shape u-shape out-shape))) - (scalarize - (flat-ext2-ρ flat-f ,m ,n ,shape-fn t u))))] - [(tcomp-ext1-ρ-scalar f tp) - `(,f (force/eval ,tp))] - [(tcomp-ext1-ρ-prealloc f m shape-fn tp) - `(scalarize (flat-ext1-ρ ,f ,m ,shape-fn (ensure-flat (force/eval ,tp))))] - [(tcomp-ext1-ρ f m shape-fn tp) - `(let ([t (ensure-flat (force/eval ,tp #f))]) - (let* ((in-shape (flat-shape t)) - (base-shape (min-shape ,m in-shape)) - (out-shape (,shape-fn base-shape)) - (flat-f (functional->preallocated-1-ρ ,f base-shape out-shape))) - (scalarize - (flat-ext1-ρ flat-f ,m ,shape-fn t))))] - [(tcomp-reshape s tp) - `(let ([t (force/eval ,tp #f)]) - (flat ',s (flat-store t) (flat-offset t)))] - [(tcomp-tensor args) - `(merge-flats (map force/eval ',args))]))) + [(tpromise tc _) (compile-expr tc counter)] + [tc #:when (flat? tc) + (inj-compiler-flat tc)] + [tc #:when (or (pair? tc) (scalar? tc)) + (inj-compiler-val tc)] + [(tcomp) (compile-tcomp tc counter)]))) (define bounded-idx*^ (λ (shape idx*) @@ -319,25 +527,35 @@ (null? (tpromise-shape tp-u))) (tpromise (tcomp-ext2-ρ-scalar f tp-t tp-u) '())] [(flat:expects-preallocated? f) - (let* ([s0 (tp-shape tp-t)] - [s1 (tp-shape tp-u)] - [sf0 (min-shape m s0)] - [sf1 (min-shape n s1)] - [sf-out (shape-fn sf0 sf1)]) + (let* ((s0 (tp-shape tp-t)) + (s1 (tp-shape tp-u)) + (sf0 (min-shape m s0)) + (sf1 (min-shape n s1)) + (sf-out (shape-fn sf0 sf1))) (tpromise - (tcomp-ext2-ρ-prealloc tp-t tp-u - f m n shape-fn) + (tcomp-ext2-ρ (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + f m n shape-fn) (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))] [else - (let* ([s0 (tp-shape tp-t)] - [s1 (tp-shape tp-u)] - [sf0 (min-shape m s0)] - [sf1 (min-shape n s1)] - [sf-out (shape-fn sf0 sf1)]) + (let* ((s0 (tp-shape tp-t)) + (s1 (tp-shape tp-u)) + (sf0 (min-shape m s0)) + (sf1 (min-shape n s1)) + (sf-out (shape-fn sf0 sf1)) + (t-shape (min-shape m s0)) + (u-shape (min-shape n s1)) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-ρ + f + t-shape + u-shape + out-shape))) (tpromise - (tcomp-ext2-ρ (ensure-tpromise tp-t) (ensure-tpromise tp-u) - f m n shape-fn) + (tcomp-ext2-ρ (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + flat-f m n shape-fn) (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) @@ -403,7 +621,10 @@ (λ (tp-t tp-u tp-z) (cond ((flat:expects-preallocated? f) - (tp-f f tp-t tp-u tp-z)) + (tp-f f + (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + (ensure-tpromise tp-z))) [else (let* ((t-shape (min-shape m (tp-shape tp-t))) (u-shape (min-shape n (tp-shape tp-u))) (out-shape (shape-fn t-shape u-shape)) @@ -486,37 +707,6 @@ ((scalar? v) (tpromise (ensure-flat v) '())) (else v)))) - -(define flat-gradient-maker2 - (λ (f m n) - (cond - ((and (zero? m) (zero? n)) - (λ (g0 - g1 - v0 i0 stride0 - v1 i1 stride1 - vz iz stride-z) - (let ((z (vref vz iz)) - (a (vref v0 i0)) - (b (vref v1 i1))) - (let-values (((da db) (f a b z))) - (vset! g0 i0 - (+ (vref g0 i0) da)) - (vset! g1 i1 - (+ (vref g1 i1) db)))))) - (else f)))) - -(define flat-gradient-maker1 - (λ (f0 m) - (cond - ((zero? m) - (λ (g0 v0 i0 stride0 vz iz stride-z) - (let ((z (vref vz iz)) - (a (vref v0 i0))) - (vset! g0 i0 (+ (vref g0 i0) - (f0 a z)))))) - (else f0)))) - (define tp-rank (λ (tp) (flat:len (tp-shape tp)))) @@ -526,12 +716,30 @@ (cond ((= (flat:size-of s) (flat:size-of (tpromise-shape tp))) (tpromise (tcomp-reshape s tp) s)) - (else (error "Cannot reshape ~a to ~a~%" (tpromise-shape tp) s))))) + (else (error 'shape-error "Cannot reshape ~a to ~a~%" (tpromise-shape tp) s))))) (define tensor? (lambda (tp) (or (tpromise? tp) (flat? tp) (scalar? tp)))) +(define get-compiled + (λ (t) + (let-values (((instrs locals env) + (run-compiler + (compile-expr t + (count-references t (hasheq))) + '() '()))) + (make-instrs instrs locals env)))) +(define run-compiled + (λ (t) + (let-values (((instrs locals env) + (run-compiler + (compile-expr t + (count-references t (hasheq))) + '() '()))) + (make-instrs instrs locals env)))) +;; TODO: may have to remove call to force*1 & force*2 so that this can be +;; handled at compile time (define force*1 (λ (t f) (f (force/eval t)))) From fc866740d6853bec1b22dea5d3b0c6a16fa457c4 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 9 Sep 2023 06:34:17 -0400 Subject: [PATCH 63/83] [add-lazy]compile racket WIP --- lazy/tensors/0-lazy.rkt | 140 ++++++++------- lazy/tensors/test/test-0-lazy.rkt | 281 +++++++++++++++++++++++++----- malted/test/test-O-init.rkt | 33 ++-- 3 files changed, 342 insertions(+), 112 deletions(-) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 57111dc..f189f01 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -1,10 +1,50 @@ #lang racket (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +#| +Questions: + +* How do I create a preallocated function example for ext2-∇? Also, how do the preallocated functions work? +A. Here are what the formal parameters of a binary preallocated ∇-function mean: + + - g0, g1: These are the empty gradient tensor stores which need to be filled +with the gradients of the corresponding ρ-function w.r.t. the first and second +arguments respectively. + + - t, it, st: the store, beginning offset and total size of the flat +representation of the first tensor argument respectively which we will need to loop through +the scalar elements. + + - u, iu, su: the store, beginning offset and total size of the flat +representation of the second tensor argument respectively which we will need to +loop through the scalar elements. + + - z, iz, sz: the store, beginning offset and total size of the flat +representation of the accumulator respectively which we will need to +loop through the scalar elements. + + Here are the invariants of the formal parameters: + + - The flat tensors corresponding to the stores g0 and g1 have the same shape +as the first and second input tensors respectively + + - The flat tensor corresponding to the store z has the same shape as the +result of invoking the corresponding ρ-function with the two tensor arguments +* Here are a few problems which need to be addressed while checking compiler +invariants after compiling the expression "(((l2-loss plane) r2d1 (tensor 1.0 +1.0)) plane-theta-0)" from the test-C-loss.rkt file: + + - The env contains flat tensors with the shape '() i.e. they scalars in +disguise + + - Somehow copies of a flat tensor are being added to the env rather than the +instructions refering to the same gensym variable + +|# ;; tensor computations -(struct tcomp ()) +(struct tcomp () #:transparent) #; (: lst (U (Listof tpromise) (Listof Number))) (struct tcomp-list->tensor tcomp (lst) #:transparent) @@ -30,18 +70,12 @@ (Vector Number) Natural (Listof Natural)))) (struct tcomp-ext2-∇ tcomp (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) #:transparent) -(struct tcomp-ext1-∇-prealloc tcomp (tp zp f m shape-fn) #:transparent) (struct tcomp-ext1-∇ tcomp (tp zp f m shape-fn) #:transparent) (struct tcomp-ext2-ρ-scalar tcomp (f tp-t tp-u) #:transparent) -(struct tcomp-ext2-ρ-prealloc tcomp (tp-t tp-u f m n shape-fn) #:transparent) (struct tcomp-ext2-ρ tcomp (tp-t tp-u f m n shape-fn) #:transparent) (struct tcomp-ext1-ρ-scalar tcomp (f tp) #:transparent) -(struct tcomp-ext1-ρ-prealloc tcomp (f m shape-fn tp) #:transparent) (struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) -#; -(: args (U (Listof tpromise) (Listof Number))) -(struct tcomp-tensor tcomp (args) #:transparent) (struct tpromise ((tensor #:mutable) shape) #:guard @@ -133,10 +167,22 @@ (define-namespace-anchor a) +(define make-instrs + (λ (instrs locals env) + (let* ([local-binds (for/fold ((binds '())) + ((l/instrs locals)) + `((,(car l/instrs) ,(cdr l/instrs)) . ,binds))] + [env-binds (for/fold ((binds '())) + ((t/instrs env)) + `((,(car t/instrs) ,(cdr t/instrs)) . ,binds))]) + `(let* ,env-binds + (let* ,local-binds + ,instrs))))) + (define run-instrs - (lambda (instrs) - (let ([env (namespace-anchor->namespace a)]) - (eval instrs env)))) + (lambda (instrs locals env) + (let ([static-env (namespace-anchor->namespace a)]) + (eval (make-instrs instrs locals env) static-env)))) (define force/eval (lambda (tp (print? #f)) @@ -460,17 +506,6 @@ [(tpromise? v) (tpromise-shape v)] [else (flat:shape v)]))) -(define list->tpromise - (λ (lst) - (cond - [(null? lst) - (error 'list->ltensor "No elements found")] - [else - (tpromise (tcomp-list->tensor lst) - `(,(length lst) - . ,(tp-shape - (car lst))))]))) - (define build-tpromise (λ (s f) (tpromise (tcomp-build-tensor s f) s))) @@ -501,20 +536,24 @@ '())] [(flat:expects-preallocated? f) (tpromise - (tcomp-ext1-ρ-prealloc f m shape-fn tp) + (tcomp-ext1-ρ f m shape-fn tp) (merge-shapes (tp-shape tp) m (shape-fn (min-shape m (tp-shape tp)))))] [else - (tpromise - (tcomp-ext1-ρ f m shape-fn tp) - (merge-shapes - (tp-shape tp) - m - (shape-fn - (min-shape m (tp-shape tp)))))])))) + (let* ((in-shape (tpromise-shape tp)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) + (tpromise + (tcomp-ext1-ρ flat-f m shape-fn tp) + (merge-shapes + (tp-shape tp) + m + (shape-fn + (min-shape m (tp-shape tp))))))])))) (define tp-ext2-ρ (λ (f m n [shape-fn scalar-shape]) @@ -573,31 +612,6 @@ (define scalar-shape (λ (s0 [s1 '()]) '())) -(define left-shape - (λ (s0 s1) s0)) - -(define right-shape - (λ (s0 s1) s1)) - -(define flat-function-maker2 - (λ (f m n) - (cond - ((and (zero? m) (zero? n)) - (λ (v0 i0 stride0 v1 i1 stride1 - v-out i-out stride-out) - (vset! v-out i-out - (f (vref v0 i0) (vref v1 i1))))) - (else - f)))) - -(define flat-function-maker1 - (λ (f m) - (cond - ((zero? m) - (λ (v0 i0 stride0 v-out i-out stride-out) - (vset! v-out i-out (f (vref v0 i0))))) - (else f)))) - (define tp-ext1-∇ (λ (f m [shape-fn scalar-shape]) (λ (tp zp) @@ -605,12 +619,16 @@ ((number? tp) (f tp zp)) ((flat:expects-preallocated? f) (tpromise - (tcomp-ext1-∇-prealloc tp zp f m shape-fn) + (tcomp-ext1-∇ tp (ensure-tpromise zp) f m shape-fn) (tp-shape tp))) (else - (tpromise - (tcomp-ext1-∇ tp zp f m shape-fn) - (tp-shape tp))))))) + (let* ((in-shape (tpromise-shape tp)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) + (tpromise + (tcomp-ext1-∇ tp (ensure-tpromise zp) flat-f m shape-fn) + (tp-shape tp)))))))) (define tp-ext2-∇ (λ (f m n [shape-fn scalar-shape]) @@ -749,7 +767,8 @@ (let-values (((t1 t2) (ts))) (f (force/eval t1) (force/eval t2))))) -(include "test/test-0-lazy.rkt") +;; TODO: uncomment +;(include "test/test-0-lazy.rkt") (provide start-vector-manager vector-manager-report) @@ -783,3 +802,6 @@ (flat:size-of size-of))) (provide force*1 force*2) + +;; TODO: delete later. For debugging only +(provide print-compiler? compile-expr make-instrs run-instrs force/eval tpromise-tensor run-compiler count-references get-compiled) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index b5777f1..cbbbde1 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -1,18 +1,50 @@ (module+ test (require rackunit) + ;; TODO: Add a comment above each test case describing what the test case is testing + (define-check (check-compiler-invariants tp) + (let-values (((instrs locals env) (run-compiler + (compile-expr tp + (count-references tp (hasheq))) + '() '()))) + (with-check-info + (('env (nested-info + (map (λ (name/flat) + (make-check-info (car name/flat) (cdr name/flat))) + env))) + ('instrs instrs)) + (for ((name/flat env)) + (unless (and (flat:flat? (cdr name/flat)) + (not (null? (flat-shape (cdr name/flat))))) + (fail-check (format (string-append "Value associated with the variable" + " ~a should be a flat tensor. " + "Associated value found: ~a") + (car name/flat) (cdr name/flat))))) + (define unique-flats (list->seteq (map cdr env))) + (unless (equal? (set-count unique-flats) + (length (filter flat? (map cdr env)))) + (fail-check (string-append "Duplicate flat tensors found" + " in environment. Variables in environment" + " should be paired with unique" + " flat tensors")))))) + (define test-lt (tensor 1 2 3)) + (check-compiler-invariants test-lt) (check-true (flat? (tpromise-tensor test-lt))) (check-equal? (flat-store (force/eval test-lt)) (vector 1 2 3)) (check-true (flat? (tpromise-tensor test-lt))) (check-exn exn:fail? (λ () (tensor test-lt 4))) (check-exn exn:fail? (λ () (tensor 4 test-lt))) - (check-equal? (force/eval (tp-tref test-lt 2)) 3) + (define test-tcomp-tref (tp-tref test-lt 2)) + (check-compiler-invariants test-tcomp-tref) + (check-equal? (force/eval test-tcomp-tref) 3) (check-exn exn:fail? (λ () (tp-tref test-lt 5))) (define test-nested-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) - (check-equal? (force/eval (tp-tref (tp-tref test-nested-lt 0) 2)) 3) + (define test-tcomp-tref-nested (tp-tref (tp-tref test-nested-lt 0) 2)) + (check-compiler-invariants test-tcomp-tref-nested) + (check-equal? (force/eval test-tcomp-tref-nested) 3) (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0))) (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2))) (check-exn exn:fail? (λ () (tensor test-nested-lt test-nested-lt test-lt))) @@ -21,33 +53,75 @@ (check-equal? (tp-tlen test-nested-lt) 2) (define test-lt-from-list (list->tpromise '(5 6 7 8))) + (check-compiler-invariants test-lt-from-list) (check-equal? (flat-store (force/eval test-lt-from-list)) (vector 5 6 7 8)) (define test-nested-lt-from-list (list->tpromise `(,test-lt ,test-lt ,test-lt))) + (check-compiler-invariants test-nested-lt-from-list) + (check-equal? (flat-store (force/eval test-nested-lt-from-list)) + (vector 1 2 3 1 2 3 1 2 3)) (check-equal? (tpromise-shape test-nested-lt-from-list) '(3 3)) (check-true (bounded-idx*? test-nested-lt-from-list (list 0 1))) (check-false (bounded-idx*? test-nested-lt-from-list (list 1 3))) (check-false (bounded-idx*? test-nested-lt-from-list (list 1 1 0))) + (define test-tcomp-partial-eval + (begin + (force/eval test-nested-lt-from-list) + (force/eval test-nested-lt) + (force/eval test-lt) + (tp-tref + (tp-tref (tensor (tensor (tensor 1 2 3) (tensor 4 5 6) (tensor 7 8 9)) + test-nested-lt-from-list + (list->tpromise (list (tp-tref test-nested-lt 0) + (tp-tref test-nested-lt 1) + test-lt))) + 1) + 2))) + (check-compiler-invariants test-tcomp-partial-eval) + (flat:check-tensor-equal? (force/eval test-tcomp-partial-eval) + (force/eval (tensor 1 2 3))) + (define test-build-shape '(4 3)) (define test-built-tensor (build-tpromise test-build-shape - (λ (i) - (let ([row (car i)] - [column (cadr i)]) - (+ (* (sub1 (car test-build-shape)) - row) - column))))) + (λ (i) + (let ([row (car i)] + [column (cadr i)]) + (+ (* (sub1 (car test-build-shape)) + row) + column))))) + (check-compiler-invariants test-built-tensor) (check-equal? (tpromise-shape test-built-tensor) test-build-shape) (check-true (tcomp? (tpromise-tensor test-built-tensor))) + (flat:check-tensor-equal? (force/eval test-built-tensor) + (force/eval (tensor (tensor 0 1 2) + (tensor 3 4 5) + (tensor 6 7 8) + (tensor 9 10 11)))) (define test-refs '(0 2)) (define test-tp-trefs (tp-trefs test-built-tensor test-refs)) + (check-compiler-invariants test-tp-trefs) (check-true (tcomp? (tpromise-tensor test-tp-trefs))) - (check-equal? (tpromise-shape test-tp-trefs) (flat-shape (force/eval test-tp-trefs))) - (check-equal? (flat-store (force/eval test-tp-trefs)) (vector 0 1 2 6 7 8)) + (check-equal? (tpromise-shape test-tp-trefs) + (flat-shape (force/eval test-tp-trefs))) + (flat:check-tensor-equal? (force/eval test-tp-trefs) + (force/eval (tensor (tensor 0 1 2) + (tensor 6 7 8)))) (check-exn exn:fail? (λ () (tp-trefs test-nested-lt '(0 4)))) + (define test-tp-reshape (tp-reshape '(3 2 1) (tp-trefs test-built-tensor '(1 3)))) + (check-compiler-invariants test-tp-reshape) + (flat:check-tensor-equal? (force/eval test-tp-reshape) + (force/eval (tensor (tensor (tensor 3) + (tensor 4)) + (tensor (tensor 5) + (tensor 9)) + (tensor (tensor 10) + (tensor 11))))) + (check-exn exn:fail? (λ () (tp-reshape '(4 5) test-tp-reshape))) + (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) (vset! out-v iₒ @@ -55,11 +129,27 @@ (+ sum (vref in-v i)))))) (define sum (tp-ext1-ρ sum-f 1)) - (check-equal? (flat-store (force/eval (sum test-nested-lt))) (vec 6.0 15.0)) + (define test-tp-sum (sum test-nested-lt)) + (check-compiler-invariants test-tp-sum) + (flat:check-tensor-equal? (force/eval test-tp-sum) + (force/eval (tensor 6.0 15.0))) + + (define test-tp-sum-nested (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) + (check-compiler-invariants test-tp-sum-nested) + (flat:check-tensor-equal? (force/eval test-tp-sum-nested) + (force/eval (tensor 4.0 6.0 5.0))) (define id-f (lambda (v) v)) (define id-ρ (tp-ext1-ρ id-f 1 (λ (s) s))) - (check-equal? (flat-store (force/eval (id-ρ test-nested-lt))) (vec 1 2 3 4 5 6)) + (define test-tp-id (id-ρ test-nested-lt)) + (check-compiler-invariants test-tp-id) + (flat:check-tensor-equal? (force/eval test-tp-id) + (force/eval (tensor (tensor 1 2 3) + (tensor 4 5 6)))) + + (define test-tp-id-scalar (id-ρ (sum (tensor 4 5 6)))) + (check-compiler-invariants test-tp-id-scalar) + (check-equal? (force/eval test-tp-id-scalar) 15.0) (define t0 (build-tpromise '(2 3 4) @@ -69,6 +159,7 @@ (define *-ρ (tp-ext2-ρ * 0 0)) (define t0sqr (*-ρ t0 t0)) + (check-compiler-invariants t0sqr) (flat:check-tensor-equal? (force/eval t0sqr) (flat:reshape '(2 3 4) @@ -84,8 +175,8 @@ (λ (v0 i0 s0 v1 i1 s1 vout iout sout) (for ([j0 (in-range 0 s0)]) (vset! vout (+ iout j0) - (* (vref v0 (+ i0 j0)) - (vref v1 (+ i1 (modulo j0 s1)))))))) + (* (vref v0 (+ i0 j0)) + (vref v1 (+ i1 (modulo j0 s1)))))))) (define t1 (build-tpromise '(5 6) @@ -104,15 +195,16 @@ (*-2-1 t1 t2)) (check-equal? (tpromise-shape r-1-2) '(5 6)) + (check-compiler-invariants r-1-2) (flat:check-tensor-equal? (force/eval r-1-2) (flat:reshape '(5 6) (flat:tensor - 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0))) + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0))) (define t3 (build-tpromise '(3 5 6) @@ -130,27 +222,32 @@ (*-2-1 t3 t4)) (check-equal? (tpromise-shape r-3-4) '(3 5 6)) + (check-compiler-invariants r-3-4) (flat:check-tensor-equal? (force/eval r-3-4) - (flat:reshape - '(3 5 6) - (flat:tensor - 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0 - - 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 - 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 - 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 - 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 - 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 - - 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 - 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 - 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 - 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 - 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + (flat:reshape + '(3 5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + + (define r-sum-2-scalar (*-ρ (sum t2) (sum (tensor 2 3 4)))) + (check-compiler-invariants r-sum-2-scalar) + (flat:check-tensor-equal? (force/eval r-sum-2-scalar) 405.0) (define r1-td (tensor 3.0 4.0 5.0)) (define r2-td (tp-reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) @@ -168,10 +265,13 @@ (λ (t) (build-tpromise (tpromise-shape t) (λ (_) 1.0)))) - (flat:check-tensor-equal? (force/eval (d-sqr r1-td (one-like r1-td))) + (define tcomp-dsqr-r1 (d-sqr r1-td (one-like r1-td))) + (check-compiler-invariants tcomp-dsqr-r1) + (flat:check-tensor-equal? (force/eval tcomp-dsqr-r1) (flat:tensor 6.0 8.0 10.0)) (let ((gsqr (d-sqr r2-td (one-like r2-td)))) + (check-compiler-invariants gsqr) (flat:check-tensor-equal? (force/eval gsqr) (flat:reshape '(2 3) @@ -179,15 +279,25 @@ (define d+ (tp-ext2-∇ +ᵈ 0 0 scalar-shape)) + (let-values (((da db) (d+ 2.0 3.0 1.0))) + (check-compiler-invariants da) + (flat:check-tensor-equal? (force/eval da) 1.0) + (check-compiler-invariants db) + (flat:check-tensor-equal? (force/eval db) 1.0)) + (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) + (check-compiler-invariants da) (flat:check-tensor-equal? (force/eval da) (flat:tensor 1.0 1.0 1.0)) + (check-compiler-invariants db) (flat:check-tensor-equal? (force/eval db) (flat:tensor 1.0 1.0 1.0))) (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) + (check-compiler-invariants da) (flat:check-tensor-equal? (force/eval da) (flat:tensor 2.0 2.0 2.0)) + (check-compiler-invariants db) (flat:check-tensor-equal? (force/eval db) (flat:reshape '(2 3) @@ -200,6 +310,8 @@ (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0)))) + (check-compiler-invariants gt) + (check-compiler-invariants gu) (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 2.0 3.0))) (flat:check-tensor-equal? (force/eval gu) (force/eval (tensor 2.0 3.0 4.0)))) @@ -212,10 +324,97 @@ (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) 1.0))) + (check-compiler-invariants gt) (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 1.0 1.0)))) (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0)))) + (check-compiler-invariants gt) (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor (tensor 2.0 2.0 2.0) - (tensor 1.0 1.0 1.0)))))) + (tensor 1.0 1.0 1.0))))) + ;; t and u must have the same shape + (define s2-f (lambda (t u) (tensor (sum t) (sum u)))) + (define s2-d + (λ (g0 g1 t it st u iu su vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g0 i (vref vz iz)) + (vset! g1 i (vref vz (+ iz 1)))))) + (define s2-∇ (tp-ext2-∇ s2-d 1 1 (λ (s0 s1) (list 2)))) + (let-values (((gt gu) (s2-∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0)))) + (check-compiler-invariants gt) + (check-compiler-invariants gu) + (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 1.0 1.0))) + (flat:check-tensor-equal? (force/eval gu) (force/eval (tensor 1.0 1.0 1.0)))) + (let-values (((gt gu) (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 5.0 6.0 6.0) + (tensor 7.0 8.0 6.0)) + (tensor (tensor 8.0 7.0 6.0) + (tensor 6.0 5.0 6.0))) + (tensor (tensor (tensor 6.0 8.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 8.0 2.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 5.0 1.0 6.0))) + (tensor (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)))))) + (check-compiler-invariants gt) + (check-compiler-invariants gu) + (flat:check-tensor-equal? (force/eval gt) (force/eval (tp-reshape '(3 2 3) (list->tpromise (make-list 18 1.0))))) + (flat:check-tensor-equal? (force/eval gu) (force/eval (tp-reshape '(3 2 3) (list->tpromise (make-list 18 1.0)))))) + + (define test-env-flat-scalar ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) (list (tensor 1.0) 3.0))) + (check-compiler-invariants test-env-flat-scalar) + + ;; Check common subexpression introduced by let is not repeated + ;; TODO: add a generic version of the next 2 tests in check-compiler-invariants + (define count-flat:tref + (λ (ls) + (cond + ((null? ls) 0) + ((pair? (car ls)) (+ (count-flat:tref (car ls)) + (count-flat:tref (cdr ls)))) + ((eqv? (car ls) 'flat:tref) + (add1 (count-flat:tref (cdr ls)))) + (else (count-flat:tref (cdr ls)))))) + (define test-common-subexpr + (let ((t (tp-tref (tensor 1 2 3) 0))) + (tensor t t))) + (let-values (((instrs locals env) (run-compiler + (compile-expr test-common-subexpr + (count-references + test-common-subexpr + (hasheq))) + '() '()))) + (check-equal? (count-flat:tref (make-instrs instrs locals env)) 1 + "Common subexpression containing flat:tref should occur once")) + (define test-common-nested-subexprs + (let ((t1 (tp-tref (tensor (tensor 1 2 3) (tensor 4 5 6)) 0))) + (let ((t0 (tp-tref t1 0))) + (tensor t0 t0)))) + (let-values (((instrs locals env) (run-compiler + (compile-expr test-common-nested-subexprs + (count-references + test-common-nested-subexprs + (hasheq))) + '() '()))) + (check-equal? (count-flat:tref (make-instrs instrs locals env)) 2 + "Common subexpressions containing flat:tref should occur twice")) + + (define random-tensor + (λ (s) + (build-tpromise s (λ (tidx) (random 10))))) + (define test-build-random + (let ((v (random-tensor '(3 2 4)))) + (*-ρ v v))) + (check-pred + (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) + (vector->list (flat:flat-store (force/eval test-build-random))))) diff --git a/malted/test/test-O-init.rkt b/malted/test/test-O-init.rkt index 1422455..95c18dc 100644 --- a/malted/test/test-O-init.rkt +++ b/malted/test/test-O-init.rkt @@ -2,18 +2,27 @@ (require rackunit) ;; TODO: Make this better. We musn't break abstraction boundaries (require "../lazy/tensors/0-lazy.rkt") - (define v (init-shape (list 1000 4))) - (define mean-v (force/eval (abs (/ (sum (sum v)) 4000)))) - (define variance-v (force/eval (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v)))) - (check-true (< mean-v 0.05)) - (check-true (and (>= variance-v 0.4) - (<= variance-v 0.6))) + + (define v (init-shape (list 10 4))) + (define mean-v + (abs (/ (sum (sum v)) 40))) + (define variance-v + (- (/ (sum (sum (* v v))) 40) (* mean-v mean-v))) + (check-true (< (force/eval mean-v) 0.05)) + (pretty-print (get-compiled variance-v)) + (check-true (let ((forced (force/eval variance-v))) + (and (>= forced 0.4) + (<= forced 0.6)))) ;; Here variance will be 2/8 = 0.25 - (define r (init-shape (list 1000 4 2))) - (define mean-r (force/eval (abs (/ (sum (sum (sum r))) 8000)))) - (define variance-r (force/eval (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r)))) + (define r (init-shape (list 10 4 2))) + (define mean-r (abs (/ (sum (sum (sum r))) 80))) + (define variance-r (- (/ (sum (sum (sum (* r r)))) 80) + (* mean-r mean-r))) - (check-true (< mean-r 0.05)) - (check-true (and (>= variance-r 0.22) - (<= variance-r 0.28)))) + (check-true (< (force/eval mean-r) 0.05)) + (pretty-print (get-compiled variance-r)) + (check-true (let ((forced (force/eval variance-r))) + (println forced) + (and (>= forced 0.22) + (<= forced 0.28))))) From 908e378587ce861264994f2f19bfae177261c439 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 20 Oct 2023 19:56:26 -0400 Subject: [PATCH 64/83] [add-lazy]Caching WIP --- Makefile | 5 + flat-tensors/ext-impl.rkt | 3 +- flat-tensors/tensors/D-extend.rkt | 2 +- lazy.rkt | 2 + lazy/autodiff/A-autodiff.rkt | 2 +- lazy/autodiff/B-prims.rkt | 26 +- lazy/autodiff/D-test-helpers.rkt | 7 +- lazy/autodiff/E-print.rkt | 3 +- lazy/ext-ops/test/test-A-scalar-ops.rkt | 6 +- lazy/tensors.rkt | 5 +- lazy/tensors/0-lazy.rkt | 473 +---------------- lazy/tensors/1-reflect.rkt | 56 ++ lazy/tensors/A-equality.rkt | 4 +- lazy/tensors/B-test-programs.rkt | 202 ++++++++ lazy/tensors/c0-ast.rkt | 74 +++ lazy/tensors/c1-racket-runtime.rkt | 84 +++ lazy/tensors/c2-interpreter.rkt | 98 ++++ lazy/tensors/c3-compiler.rkt | 660 ++++++++++++++++++++++++ lazy/tensors/test/test-0-lazy.rkt | 421 +-------------- lazy/tensors/test/test-1-reflect.rkt | 280 ++++++++++ lazy/tensors/test/test-A-equality.rkt | 1 + lazy/tensors/test/test-c3-compiler.rkt | 73 +++ malted/test/test-O-init.rkt | 26 +- 23 files changed, 1591 insertions(+), 922 deletions(-) create mode 100644 lazy/tensors/1-reflect.rkt create mode 100644 lazy/tensors/B-test-programs.rkt create mode 100644 lazy/tensors/c0-ast.rkt create mode 100644 lazy/tensors/c1-racket-runtime.rkt create mode 100644 lazy/tensors/c2-interpreter.rkt create mode 100644 lazy/tensors/c3-compiler.rkt create mode 100644 lazy/tensors/test/test-1-reflect.rkt create mode 100644 lazy/tensors/test/test-c3-compiler.rkt diff --git a/Makefile b/Makefile index a6fc7d0..25302c0 100644 --- a/Makefile +++ b/Makefile @@ -25,7 +25,12 @@ LAZY_AUTODIFF_DIR=$(LAZY_DIR)/autodiff LAZY_EXT_OPS_DIR=$(LAZY_DIR)/ext-ops LAZY_TENSORS_SOURCES=\ + $(LAZY_TENSORS_DIR)/c0-ast.rkt\ + $(LAZY_TENSORS_DIR)/c1-racket-runtime.rkt\ + $(LAZY_TENSORS_DIR)/c2-interpreter.rkt\ + $(LAZY_TENSORS_DIR)/c3-compiler.rkt\ $(LAZY_TENSORS_DIR)/0-lazy.rkt\ + $(LAZY_TENSORS_DIR)/1-reflect.rkt\ $(LAZY_TENSORS_DIR)/A-equality.rkt\ $(LAZY_DIR)/tensors.rkt diff --git a/flat-tensors/ext-impl.rkt b/flat-tensors/ext-impl.rkt index 8a47635..6eb7d34 100644 --- a/flat-tensors/ext-impl.rkt +++ b/flat-tensors/ext-impl.rkt @@ -15,7 +15,8 @@ functional->preallocated-2-ρ functional->preallocated-2-∇ idxs - scalarize)) + scalarize + ensure-flat)) (require (only-in "autodiff/E-print.rkt" make-printable-flat fake-tensor)) diff --git a/flat-tensors/tensors/D-extend.rkt b/flat-tensors/tensors/D-extend.rkt index 6670747..e2e5501 100644 --- a/flat-tensors/tensors/D-extend.rkt +++ b/flat-tensors/tensors/D-extend.rkt @@ -389,4 +389,4 @@ functional->preallocated-1-ρ functional->preallocated-1-∇ functional->preallocated-2-ρ functional->preallocated-2-∇ merge-shapes min-shape ext2-shapes idxs - flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize) + flat-ext1-∇ flat-ext1-ρ flat-ext2-ρ scalarize ensure-flat) diff --git a/lazy.rkt b/lazy.rkt index 59b0e9a..87664db 100644 --- a/lazy.rkt +++ b/lazy.rkt @@ -14,6 +14,8 @@ ext1-ρ ext2-ρ ext1-∇ ext2-∇ + print-compiler? + dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* ext1 ext2 prim1 prim2 diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index df57f32..e7399a1 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -73,7 +73,7 @@ (λ (y wrt) (let ((σ (∇σ y (hasheq)))) (map* (λ (d) - (force/eval (hash-ref σ d 0.0))) + (↓ (hash-ref σ d 0.0))) wrt)))) (define ∇σ diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index ff532af..ee39e10 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -9,12 +9,26 @@ (define ∇-function (λ (f) (f ∇-function))) -;;TODO: add more metadata to functions so that we know which function is being -;; passed to the extend functions. - (define shape-fn (λ (f) (f shape-fn))) +;;TODO: make prim1 and prim2 func-callable structures using the prop:procedure +;;struct property + +(struct prim (ρ-fn ∇-fn shape-fn + signature ;;autogenerate this before runtime to avoid + ;;changing this during runtime + proc ;; This will be the prim*-dual func + prealloc? ;; use this to redefine expects-preallocated? + ) + #:property prop:procedure (λ (this . args) + (apply (prim-proc this) args))) + +;;TODO: move expects-preallocated?, functional->preallocated-1-ρ, +;;functional->preallocated-1-∇, functional->preallocated-2-ρ, +;;functional->preallocated-2-∇ here because they depend on the representation of +;;prims + (define prim1 (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) (λ (daf) @@ -49,12 +63,6 @@ (rb (ρ db))) (dual (ρ-fn ra rb) (λ (d z σ) - #; - (let-values (((ga gb) (∇-fn ra rb z) - ;; TODO: define a force*-2 for this - #;(force*-2 z (lambda (z) (∇-fn ra rb z))))) - (let ((σ-hat ((κ da) da (tp-force ga) σ))) - ((κ db) db (tp-force gb) σ-hat))) (force*2 (λ () (∇-fn ra rb z)) (λ (ga gb) (let ((σ-hat ((κ da) da ga σ))) diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index 6d11570..9da4331 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -1,19 +1,18 @@ #lang racket (require "../tensors.rkt") -(require (only-in "../tensors/0-lazy.rkt" force/eval)) (require "A-autodiff.ss") (require rackunit) (define forced-ρ (λ (d) - (force/eval (ρ d)))) + (↓ (ρ d)))) (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) - (let* ((y (force/eval (apply fn args))) - (g (force/eval (apply (∇¹ fn) args))) + (let* ((y (↓ (apply fn args))) + (g (↓ (apply (∇¹ fn) args))) (ans-ρ (ρ ans))) (cond ((and (equal-wt? ans-ρ (ρ y)) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt index 9ae1afa..270083d 100644 --- a/lazy/autodiff/E-print.rkt +++ b/lazy/autodiff/E-print.rkt @@ -2,6 +2,7 @@ (require "A-autodiff.rkt") (require "../tensors/0-lazy.rkt") +(require "../tensors/1-reflect.rkt") (require (except-in "../../flat-tensors/ext-impl.rkt" scalarize)) (define max-tensor-print-length (make-parameter 5)) @@ -11,7 +12,7 @@ (cond ((dual? y) (make-printable (ρ y))) ((tpromise? y) - (make-printable (force/eval y) max-length)) + (make-printable (↓ y) max-length)) ((flat? y) (make-printable-flat y max-length)) ((list? y) (map (λ (le) (make-printable le max-length)) y)) diff --git a/lazy/ext-ops/test/test-A-scalar-ops.rkt b/lazy/ext-ops/test/test-A-scalar-ops.rkt index 2c13e39..4a5f87e 100644 --- a/lazy/ext-ops/test/test-A-scalar-ops.rkt +++ b/lazy/ext-ops/test/test-A-scalar-ops.rkt @@ -1,12 +1,13 @@ (module+ test (require rackunit) - (require (only-in "../tensors.rkt" tensor)) + (require (only-in "../tensors.rkt" tensor print-compiler?)) ;; Check basic numericals (let ((a 2) (b 3)) (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) - (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0)) + (parameterize ((print-compiler? '(Cache-Hit))) + (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0))) (check-ρ-∇ (d* a b) 6 (list 3.0 2.0)) (check-ρ-∇ (d/ a b) 2/3 @@ -35,7 +36,6 @@ (check-dual-equal? ((∇¹ z) a a) (list 2.5 2.0))) ;; Check numericals with vector-duals - (let ((a (tensor 2.0 3.0 4.0)) (b (tensor 3.0 8.0 9.0))) (check-ρ-∇ (d+ a b) (tensor 5.0 11.0 13.0) (list (tensor 1.0 1.0 1.0) (tensor 1.0 1.0 1.0))) diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt index dbbf905..21db9f5 100644 --- a/lazy/tensors.rkt +++ b/lazy/tensors.rkt @@ -1,5 +1,6 @@ #lang racket (require "tensors/0-lazy.rkt") +(require "tensors/1-reflect.rkt") (require "tensors/A-equality.rkt") (provide start-vector-manager vector-manager-report) @@ -11,7 +12,9 @@ (provide ext1-ρ ext2-ρ ext1-∇ ext2-∇) -(provide force/eval scalarize) +(provide ↓ scalarize) + +(provide print-compiler?) ;; These will get overriden by duals (provide tensor?) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index f189f01..e823ff7 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -1,6 +1,10 @@ #lang racket (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + +(require "c0-ast.rkt") +(require (only-in "c1-racket-runtime.rkt" ext2-∇-result)) + #| Questions: @@ -43,58 +47,6 @@ instructions refering to the same gensym variable |# -;; tensor computations -(struct tcomp () #:transparent) -#; -(: lst (U (Listof tpromise) (Listof Number))) -(struct tcomp-list->tensor tcomp (lst) #:transparent) -#; -(: s (Listof Natural)) ;; non-empty -#; -(: f (-> (Listof Natural) Number)) -(struct tcomp-build-tensor tcomp (s f) #:transparent) -#; -(: tp tpromise) -#; -(: i Natural) -(struct tcomp-tref tcomp (tp i) #:transparent) -#; -(: tp tpromise) -#; -(: i (Listof Natural)) -(struct tcomp-trefs tcomp (tp b) #:transparent) -#; -(: fᵈ (U (-> Number Number (Values Number Number)) - (-> (Vector Number) Natural (Listof Natural) - (Vector Number) Natural (Listof Natural) - (Vector Number) Natural (Listof Natural)))) -(struct tcomp-ext2-∇ tcomp (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - #:transparent) -(struct tcomp-ext1-∇ tcomp (tp zp f m shape-fn) #:transparent) -(struct tcomp-ext2-ρ-scalar tcomp (f tp-t tp-u) #:transparent) -(struct tcomp-ext2-ρ tcomp (tp-t tp-u f m n shape-fn) #:transparent) -(struct tcomp-ext1-ρ-scalar tcomp (f tp) #:transparent) -(struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) -(struct tcomp-reshape tcomp (s tp) #:transparent) - -(struct tpromise ((tensor #:mutable) shape) - #:guard - (λ (tensor shape name) - (unless (or (flat? tensor) (tcomp? tensor)) - (error 'make-tpromise - (string-append - "First argument must be either a" - " tcomp or a flat tensor. Got ~a") - tensor)) - (unless ((listof positive-integer?) shape) - (error 'make-tpromise - (string-append - "Second argument must be a list" - " of positive integers. Got ~a") - shape)) - (values tensor shape)) - #:transparent) - #; (: scalar? (-> Any Boolean)) (define scalar? number?) @@ -157,318 +109,6 @@ instructions refering to the same gensym variable (new-shape (cons outer inner-shape))) (tpromise inner-flat new-shape))))))) -#; -(: ensure-flat (-> (U flat Number) flat)) -(define ensure-flat - (λ (v) - (cond - ((scalar? v) (flat '() (vec v) 0)) - (else v)))) - -(define-namespace-anchor a) - -(define make-instrs - (λ (instrs locals env) - (let* ([local-binds (for/fold ((binds '())) - ((l/instrs locals)) - `((,(car l/instrs) ,(cdr l/instrs)) . ,binds))] - [env-binds (for/fold ((binds '())) - ((t/instrs env)) - `((,(car t/instrs) ,(cdr t/instrs)) . ,binds))]) - `(let* ,env-binds - (let* ,local-binds - ,instrs))))) - -(define run-instrs - (lambda (instrs locals env) - (let ([static-env (namespace-anchor->namespace a)]) - (eval (make-instrs instrs locals env) static-env)))) - -(define force/eval - (lambda (tp (print? #f)) - (when print? - (printf "~n####PP tensor: ") - (pretty-print tp)) - (match tp - [(tpromise t _) - #:when (or (flat? t) (scalar? t) (tcomp? t)) - - (let-values (((instrs locals env) - (run-compiler (compile-expr t (count-references t (hasheq))) - '() - '()))) - (let ((res (run-instrs instrs locals env))) - (set-tpromise-tensor! tp res) - res))] - ;; NOTE: This case runs when we use tp-scalarize to turn - ;; the tensor to a scalar - (_ tp)))) - -(struct counter-data (binding-name - ref-count - (compiled? #:mutable #:auto)) - #:transparent) - -;; Count references so that the tcomp AST nodes that refer to the same memory -;; location i.e. common AST nodes get extracted by let-binding them in the -;; compiled output racket code. -(define count-tcomp-references - (λ (tc counter) - (match-let (((counter-data tc-binding-name tc-ref-count _) - (hash-ref counter tc - (λ () - (counter-data (gensym 'local) 0))))) - (let ((counter^ (hash-set counter tc - (counter-data tc-binding-name - (add1 tc-ref-count))))) - (match tc - [(tcomp-list->tensor lst) - (for/fold - ((counter^^ counter^)) - ((l lst)) - (count-references l counter^^))] - [(tcomp-build-tensor s f) counter^] - [(tcomp-tref tp i) - (count-references tp counter^)] - [(tcomp-trefs tp b) - (count-references tp counter^)] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (count-references - tp-z - (count-references - tp-t1 - (count-references tp-t0 counter^)))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - (count-references - zp - (count-references tp counter^))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (count-references - tp-u - (count-references tp-t counter^))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (count-references - tp-u - (count-references tp-t counter^))] - [(tcomp-ext1-ρ-scalar f tp) - (count-references tp counter^)] - [(tcomp-ext1-ρ f m shape-fn tp) - (count-references tp counter^)] - [(tcomp-reshape s tp) - (count-references tp counter^)]))))) - -(define count-references - (λ (t counter) - (match t - ((tpromise tc _) - (count-references tc counter)) - ((tcomp) (count-tcomp-references t counter)) - (_ counter)))) - -;; TODO: Add tcomp nodes for let and var so that common refs to tcomp can become -;; tcomp-vars and be bound in a tcomp-let. -;; -;; TODO: Add another intermediate pass that generates tcomp-let and tcomp-var -;; nodes in such a way that it converts the abstract syntax graph to a abstract -;; syntax tree. -#| -(tcomp ..) => - (tcomp-let ((var (tcomp ...))) (tcomp ... var ....)) - -(tpromise (tcomp ... (tcomp-var name) ...)) -|# - -(define extend-env - (λ (k v env) - `((,k . ,v) . ,env))) -(define extend-locals extend-env) - -(define exists-in-env? - (λ (ft env) - (match env - ('() #f) - (`((,k . ,v) . ,_) #:when (eq? ft v) k) - (`(,_ . ,rest-env) (exists-in-env? ft rest-env))))) - -;; (: run-compiler (All (Instrs Env) (-> Env (Values Instrs Env)))) -(struct Compiler (run-compiler) #:transparent) - -(define run-compiler - (λ (c locals env) - ((Compiler-run-compiler c) locals env))) - -(define inj-compiler-val - (λ (v) - (Compiler (λ (locals env) (values v locals env))))) - -;; TODO: In the future scalars must be in the environment rather than the -;; compiled instrs. This will make the instruction signature fully independent -;; of the data in the environmnt. -(define inj-compiler-flat - (λ (ft) - (Compiler - (λ (locals env) - (cond - ((exists-in-env? ft env) => (lambda (var) (values var locals env))) - ((and (flat? ft) (null? (flat-shape ft))) - (values ft locals env)) - (else - (let ((new-var (gensym 'ft))) - (values new-var locals (extend-env new-var ft env))))))))) - -(define inj-compiler-instrs - (λ (instrs cd) - (match-let (((counter-data binding-var ref-count _) cd)) - (Compiler (λ (locals env) - (pretty-print instrs) - (cond - ((<= ref-count 1) - (println "ref <= 1") - (values instrs locals env)) - ((assv binding-var locals) - (println "assv") - (values binding-var locals env)) - (else - (println "else") - (values binding-var - (extend-locals binding-var instrs locals) - env)))))))) - -(define ->c - (λ (c f) - (Compiler (λ (locals env) - (let-values (((instrs locals^ env^) (run-compiler c locals env))) - (run-compiler (f instrs) locals^ env^)))))) - - -(define compile-tcomp - (λ (tc counter) - (let ((tc-counter-data - (hash-ref counter tc - (λ () - (counter-data (gensym 'illegal) 0))))) - (match tc - [(tcomp-list->tensor lst) - (let ((instrs-list-compiler - (for/foldr ((list-compiler (inj-compiler-val '()))) - ((arg lst)) - (->c - (compile-expr arg counter) - (λ (instrs) - (->c - list-compiler - (λ (instrs-list) - (inj-compiler-val (cons instrs instrs-list))))))))) - (->c - instrs-list-compiler - (λ (instrs-list) - (inj-compiler-instrs `(flat:list->tensor (list ,@instrs-list)) - tc-counter-data))))] - [(tcomp-build-tensor s f) - (inj-compiler-flat (flat:build-tensor s f))] - [(tcomp-tref tp i) - (->c - (compile-expr tp counter) - (λ (instrs) - (inj-compiler-instrs `(flat:tref ,instrs ,i) tc-counter-data)))] - [(tcomp-trefs tp b) - (->c - (compile-expr tp counter) - (λ (instrs) - (inj-compiler-instrs `(flat:trefs ,instrs ',b) tc-counter-data)))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (->c - (compile-expr tp-t0 counter) - (λ (t0-instrs) - (->c - (compile-expr tp-t1 counter) - (λ (t1-instrs) - (->c - (compile-expr tp-z counter) - (λ (z-instrs) - (inj-compiler-instrs - `(let* ([b (if (zero? ,i) ,out0 ,out1)] - [v (ext2-∇-result-res b)]) - (cond - ((eqv? v 'uncalculated) - (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn - ,t0-instrs ,t1-instrs - ,z-instrs ,out0 ,out1) - (ext2-∇-result-res b)) - (else v))) - tc-counter-data)))))))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - (->c - (compile-expr tp counter) - (λ (t-instrs) - (->c - (compile-expr zp counter) - (λ (z-instrs) - (inj-compiler-instrs - `(scalarize - (flat-ext1-∇ ,f ,m ,shape-fn - (ensure-flat ,t-instrs) - (ensure-flat ,z-instrs))) - tc-counter-data)))))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (->c - (compile-expr tp-t counter) - (λ (t-instrs) - (->c - (compile-expr tp-u counter) - (λ (u-instrs) - (inj-compiler-instrs - `(,f ,t-instrs ,u-instrs) - tc-counter-data)))))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (->c - (compile-expr tp-t counter) - (λ (t-instrs) - (->c - (compile-expr tp-u counter) - (λ (u-instrs) - (inj-compiler-instrs - `(scalarize - (flat-ext2-ρ ,f ,m ,n ,shape-fn - (ensure-flat ,t-instrs) - (ensure-flat ,u-instrs))) - tc-counter-data)))))] - [(tcomp-ext1-ρ-scalar f tp) - (->c - (compile-expr tp counter) - (λ (instrs) - (inj-compiler-instrs `(,f ,instrs) - tc-counter-data)))] - [(tcomp-ext1-ρ f m shape-fn tp) - (->c - (compile-expr tp counter) - (λ (instrs) - (inj-compiler-instrs `(scalarize - (flat-ext1-ρ ,f ,m ,shape-fn - (ensure-flat ,instrs))) - tc-counter-data)))] - [(tcomp-reshape s tp) - (->c - (compile-expr tp counter) - (λ (instrs) - (inj-compiler-instrs `(flat ',s - (flat-store ,instrs) - (flat-offset ,instrs)) - tc-counter-data)))])))) - -(define print-compiler? (make-parameter #f)) -(define compile-expr - (λ (tc counter) - (when (print-compiler?) - (pretty-print tc)) - (match tc - [(tpromise tc _) (compile-expr tc counter)] - [tc #:when (flat? tc) - (inj-compiler-flat tc)] - [tc #:when (or (pair? tc) (scalar? tc)) - (inj-compiler-val tc)] - [(tcomp) (compile-tcomp tc counter)]))) - (define bounded-idx*^ (λ (shape idx*) (match `(,shape ,idx*) @@ -598,17 +238,6 @@ instructions refering to the same gensym variable (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) -;; We may have to replace tp-scalarize with scalarize from flat-tensors, because -;; the force/eval used in its definition is undesirable. -(define tp-scalarize - (λ (tp) - (cond - [(and (tpromise? tp) (null? (tpromise-shape tp))) - (tp-scalarize (force/eval tp))] - [(and (flat? tp) (null? (flat-shape tp))) - (vref (flat-store tp) 0)] - [else tp]))) - (define scalar-shape (λ (s0 [s1 '()]) '())) @@ -653,62 +282,7 @@ instructions refering to the same gensym variable (ensure-tpromise tp-u) (ensure-tpromise tp-z)))]))))) -(define ext2-∇-forcer - (λ (fᵈ r0 r1 shape-fn t0 t1 z out0 out1) - (let* ((f0 (ensure-flat t0)) - (f1 (ensure-flat t1)) - (fz (ensure-flat z)) - - (s0 (flat-shape f0)) - (sf0 (min-shape r0 s0)) - (stride0 (flat:size-of sf0)) - - (s1 (flat-shape t1)) - (sf1 (min-shape r1 s1)) - (stride1 (flat:size-of sf1)) - - (sf-z (shape-fn sf0 sf1)) - (stride-z (flat:size-of sf-z)) - - (v0 (flat-store f0)) - (v1 (flat-store f1)) - (vz (flat-store fz)) - - (off0 (flat-offset f0)) - (off1 (flat-offset f1)) - (offz (flat-offset fz))) - (ext2-shapes - s0 s1 r0 r1 sf-z - (λ (sz size-z q0 q1 strides) - (let ((g0 (new-vec (flat:size-of - s0) - 0.0)) - (g1 (new-vec (flat:size-of - s1) - 0.0))) - (for ([iz (in-range - 0 - size-z - stride-z)]) - (let-values (((i0 i1) - (idxs - strides - iz - off0 - off1))) - (fᵈ g0 g1 v0 i0 - stride0 - v1 i1 - stride1 - vz - (+ offz iz) - stride-z))) - (set-ext2-∇-result-res! out0 - (tp-scalarize (flat s0 g0 0))) - (set-ext2-∇-result-res! out1 - (tp-scalarize (flat s1 g1 0))))))))) - -(struct ext2-∇-result (res) #:mutable #:transparent) + (define tp-d-ext2^ (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) (let* ((out0 (ext2-∇-result 'uncalculated)) @@ -740,35 +314,7 @@ instructions refering to the same gensym variable (lambda (tp) (or (tpromise? tp) (flat? tp) (scalar? tp)))) -(define get-compiled - (λ (t) - (let-values (((instrs locals env) - (run-compiler - (compile-expr t - (count-references t (hasheq))) - '() '()))) - (make-instrs instrs locals env)))) -(define run-compiled - (λ (t) - (let-values (((instrs locals env) - (run-compiler - (compile-expr t - (count-references t (hasheq))) - '() '()))) - (make-instrs instrs locals env)))) -;; TODO: may have to remove call to force*1 & force*2 so that this can be -;; handled at compile time -(define force*1 - (λ (t f) - (f (force/eval t)))) - -(define force*2 - (λ (ts f) - (let-values (((t1 t2) (ts))) - (f (force/eval t1) (force/eval t2))))) - -;; TODO: uncomment -;(include "test/test-0-lazy.rkt") +(include "test/test-0-lazy.rkt") (provide start-vector-manager vector-manager-report) @@ -777,10 +323,8 @@ instructions refering to the same gensym variable (flat:ref ref) (flat:refr refr))) (provide tensor - force/eval tpromise? (rename-out - (tp-scalarize scalarize) (tp-tref tref) (tp-tlen tlen) (list->tpromise list->tensor) @@ -800,8 +344,3 @@ instructions refering to the same gensym variable (tp-shape shape) (tp-reshape reshape) (flat:size-of size-of))) - -(provide force*1 force*2) - -;; TODO: delete later. For debugging only -(provide print-compiler? compile-expr make-instrs run-instrs force/eval tpromise-tensor run-compiler count-references get-compiled) diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt new file mode 100644 index 0000000..313c33c --- /dev/null +++ b/lazy/tensors/1-reflect.rkt @@ -0,0 +1,56 @@ +#lang racket +(require "../../flat-tensors/ext-impl.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require "c0-ast.rkt") +(require (only-in "c3-compiler.rkt" + print-compiler? + get-compiled + compile-tensor)) +(require (only-in "c2-interpreter.rkt" interp-racket)) + +(define ↓ + (lambda (tp (print? #f)) + (when print? + (printf "~n####PP tensor: ") + (pretty-print tp)) + (match tp + [(tpromise t _) + #:when (or (flat:flat? t) (number? t) (tcomp? t)) + + (let-values (((instrs env) + (compile-tensor t))) + (let ((res (interp-racket instrs env))) + (set-tpromise-tensor! tp res) + res))] + ;; NOTE: This case runs when we use tp-scalarize to turn + ;; the tensor to a scalar + (_ tp)))) + +;; We may have to replace tp-scalarize with scalarize from flat-tensors, because +;; the ↓ used in its definition is undesirable. +(define tp-scalarize + (λ (tp) + (cond + [(and (tpromise? tp) (null? (tpromise-shape tp))) + (tp-scalarize (↓ tp))] + [(and (flat:flat? tp) (null? (flat:flat-shape tp))) + (vector-ref (flat:flat-store tp) 0)] + [else tp]))) + +;; TODO: these force functions will be moved to the openCL runtime +(define force*1 + (λ (t f) + (f (↓ t)))) + +(define force*2 + (λ (ts f) + (let-values (((t1 t2) (ts))) + (f (↓ t1) (↓ t2))))) + +(include "test/test-1-reflect.rkt") + +(provide ↓ force*1 force*2) + +(provide print-compiler? get-compiled + (rename-out + (tp-scalarize scalarize))) diff --git a/lazy/tensors/A-equality.rkt b/lazy/tensors/A-equality.rkt index a03fa26..7c420ca 100644 --- a/lazy/tensors/A-equality.rkt +++ b/lazy/tensors/A-equality.rkt @@ -1,11 +1,11 @@ #lang racket -(require "0-lazy.rkt") +(require "1-reflect.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (define tp-tensor-equal? (λ (tp-actual tp-expected) - (flat:tensor-equal? (force/eval tp-actual) (force/eval tp-expected)))) + (flat:tensor-equal? (↓ tp-actual) (↓ tp-expected)))) (require rackunit) (define-binary-check (tp-check-tensor-equal? tp-tensor-equal? actual expected)) diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt new file mode 100644 index 0000000..c01e5b6 --- /dev/null +++ b/lazy/tensors/B-test-programs.rkt @@ -0,0 +1,202 @@ +#lang racket +(require "0-lazy.rkt") +(require "../../flat-tensors/ext-impl.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + +(define make-tref-test-program + (λ (t) + (tref t 2))) +(define make-list->tensor-test-program + (λ (l) + (list->tensor l))) + +(struct test-program-data (prog-thunk eval-res) #:transparent) +(define test-programs + (hasheqv + 'tensor-r1-0 (test-program-data + (λ () + (tensor 1 2 3)) + (flat:tensor 1 2 3)) + 'tensor-r1-1 (test-program-data + (λ () + (tensor 1 2 3 4 5)) + (flat:tensor 1 2 3 4 5)) + 'tensor-r2-0 (test-program-data + (λ () + (tensor (tensor 1 2 3) (tensor 4 5 6))) + (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6))))) +(define get-test-program + (λ (name) + ((test-program-data-prog-thunk (hash-ref test-programs name))))) +(define get-test-eval-res + (λ (name) + (test-program-data-eval-res (hash-ref test-programs name)))) + +(define test-tcomp-tref (make-tref-test-program (get-test-program 'tensor-r1-0))) +(define test-tcomp-tref-nested (tref (tref (get-test-program 'tensor-r2-0) 0) 2)) +(define test-list->tensor (make-list->tensor-test-program '(5 6 7 8))) +(define test-nested-list->tensor + (list->tensor `(,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0)))) +(define test-build-shape '(4 3)) +(define test-built-tensor (build-tensor test-build-shape + (λ (i) + (let ([row (car i)] + [column (cadr i)]) + (+ (* (sub1 (car test-build-shape)) + row) + column))))) +(define test-refs '(0 2)) +(define test-trefs (trefs test-built-tensor test-refs)) +(define test-reshape (reshape '(3 2 1) (trefs test-built-tensor '(1 3)))) + +(define sum-f + (λ (in-v iᵢ sᵢ out-v iₒ sₒ) + (vset! out-v iₒ + (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) + (+ sum (vref in-v i)))))) + +(define sum (ext1-ρ sum-f 1)) +(define test-tp-sum (sum (get-test-program 'tensor-r2-0))) +(define test-tp-sum-nested (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) + +(define id-f (lambda (v) v)) +(define id-ρ (ext1-ρ id-f 1 (λ (s) s))) +(define test-tp-id (id-ρ (get-test-program 'tensor-r2-0))) +(define test-tp-id-scalar (id-ρ (sum (tensor 4 5 6)))) + +(define t0 + (build-tensor '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) +(define *-ρ (ext2-ρ * 0 0)) +(define t0sqr (*-ρ t0 t0)) + +(define *-2-1-f + (λ (v0 i0 s0 v1 i1 s1 vout iout sout) + (for ([j0 (in-range 0 s0)]) + (vset! vout (+ iout j0) + (* (vref v0 (+ i0 j0)) + (vref v1 (+ i1 (modulo j0 s1)))))))) + +(define t1 + (build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))) + +(define t2 + (build-tensor '(6) + (λ (i) (* 3.0 (car i))))) + +(define *-2-1 + (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) + +(define r-1-2 + (*-2-1 t1 t2)) + +(define t3 + (build-tensor '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + +(define t4 + (build-tensor '(3 6) + (λ (i) + (match-define `(,x ,y) i) + (* 3.0 (+ (* x 6) y))))) + +(define r-3-4 + (*-2-1 t3 t4)) + +(define r-sum-2-scalar (*-ρ (sum t2) (sum (tensor 2 3 4)))) + +(define r1-td (tensor 3.0 4.0 5.0)) +(define r2-td (reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) + +(define +ᶠ +) +(define +ᵈ (λ (a b z) (values z z))) + +(define sqrᶠ (λ (a) (* a a))) +(define sqrᵈ + (λ (a z) (* z 2 a))) + +(define d-sqr (ext1-∇ sqrᵈ 0)) + +(define one-like + (λ (t) + (build-tensor (shape t) (λ (_) 1.0)))) + +(define tcomp-dsqr-r1 (d-sqr r1-td (one-like r1-td))) + +(define d+ (ext2-∇ +ᵈ 0 0)) + +(define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) 0 0)) + +(define sum-1-∇ + (λ (g t it st vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g i (vref vz iz))))) + +(define sum-∇ (ext1-∇ sum-1-∇ 1 (λ (s) '()))) + +;; t and u must have the same shape +(define s2-f (lambda (t u) (tensor (sum t) (sum u)))) +(define s2-d + (λ (g0 g1 t it st u iu su vz iz sz) + (for* ([i (in-range it (+ it st))]) + (vset! g0 i (vref vz iz)) + (vset! g1 i (vref vz (+ iz 1)))))) +(define s2-∇ (ext2-∇ s2-d 1 1 (λ (s0 s1) (list 2)))) + +(define test-env-flat-scalar + ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) + (list (tensor 1.0) 3.0))) + +;; Check common subexpression introduced by let is not repeated +(define test-common-subexpr + (let ((t (tref (tensor 1 2 3) 0))) + (tensor t t))) + +(define test-common-nested-subexprs + (let ((t1 (tref (tensor (tensor 1 2 3) (tensor 4 5 6)) 0))) + (let ((t0 (tref t1 0))) + (tensor t0 t0)))) + +(define random-tensor + (λ (s) + (build-tensor s (λ (tidx) (random 10))))) +(define test-build-random + (let ((v (random-tensor '(3 2 4)))) + (*-ρ v v))) + +(define +-ρ (ext2-ρ + 0 0)) +(define /-ρ (ext2-ρ / 0 0)) +(define --ρ (ext2-ρ - 0 0)) +(define abs-ρ (ext1-ρ abs 0)) + +(define mean + (λ (t) + (abs-ρ (/-ρ (sum (sum t)) (size-of (shape t)))))) +(define variance + (λ (t) + (--ρ (/-ρ (sum (sum (*-ρ t t))) (size-of (shape t))) + (*-ρ (mean t) (mean t))))) + +(define v (random-tensor '(10 4))) +(define mean-v (mean v)) +(define variance-v (variance v)) + +(define r (random-tensor '(10 4 2))) +(define mean-r (mean r)) +(define variance-r (variance r)) + +(define -ᶠ -) +(define -ᵈ (λ (a b z) (values z (- z)))) +(define d- (ext2-∇ -ᵈ 0 0)) + + +(provide (all-defined-out)) diff --git a/lazy/tensors/c0-ast.rkt b/lazy/tensors/c0-ast.rkt new file mode 100644 index 0000000..1cf0c7c --- /dev/null +++ b/lazy/tensors/c0-ast.rkt @@ -0,0 +1,74 @@ +#lang racket +(require "../../flat-tensors/ext-impl.rkt") + +;; tensor computations +(struct tcomp () #:transparent) +#; +(: lst (U (Listof tpromise) (Listof Number))) +(struct tcomp-list->tensor tcomp (lst) #:transparent) +#; +(: s (Listof Natural)) ;; non-empty +#; +(: f (-> (Listof Natural) Number)) +(struct tcomp-build-tensor tcomp (s f) #:transparent) +#; +(: tp tpromise) +#; +(: i Natural) +(struct tcomp-tref tcomp (tp i) #:transparent) +#; +(: tp tpromise) +#; +(: i (Listof Natural)) +(struct tcomp-trefs tcomp (tp b) #:transparent) +#; +(: fᵈ (U (-> Number Number (Values Number Number)) + (-> (Vector Number) Natural (Listof Natural) + (Vector Number) Natural (Listof Natural) + (Vector Number) Natural (Listof Natural)))) +(struct tcomp-ext1-ρ-scalar tcomp (f tp) #:transparent) +(struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) +(struct tcomp-ext2-ρ-scalar tcomp (f tp-t tp-u) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u f m n shape-fn) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp f m shape-fn) #:transparent) +(struct tcomp-ext2-∇ tcomp (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + #:transparent) +(struct tcomp-reshape tcomp (s tp) #:transparent) +(struct tcomp-let tcomp (lhs rhs body) #:transparent) +(struct tcomp-var tcomp (name) #:transparent) +(struct tcomp-ds-ref tcomp (index) #:transparent) + +(struct tpromise ((tensor #:mutable) shape) + #:guard + (λ (tensor shape name) + (unless (or (flat? tensor) (tcomp? tensor)) + (error 'make-tpromise + (string-append + "First argument must be either a" + " tcomp or a flat tensor. Got ~a") + tensor)) + (unless ((listof positive-integer?) shape) + (error 'make-tpromise + (string-append + "Second argument must be a list" + " of positive integers. Got ~a") + shape)) + (values tensor shape)) + #:transparent) + +(provide (struct-out tcomp) + (struct-out tcomp-list->tensor) + (struct-out tcomp-build-tensor) + (struct-out tcomp-tref) + (struct-out tcomp-trefs) + (struct-out tcomp-ext1-ρ-scalar) + (struct-out tcomp-ext1-ρ) + (struct-out tcomp-ext2-ρ-scalar) + (struct-out tcomp-ext2-ρ) + (struct-out tcomp-ext1-∇) + (struct-out tcomp-ext2-∇) + (struct-out tcomp-reshape) + (struct-out tcomp-let) + (struct-out tcomp-var) + (struct-out tcomp-ds-ref) + (struct-out tpromise)) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt new file mode 100644 index 0000000..62db3a8 --- /dev/null +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -0,0 +1,84 @@ +#lang racket + +(require "../../flat-tensors/ext-impl.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + +(struct ext2-∇-result (res) #:mutable #:transparent) + +(define ext2-∇-forcer + (λ (fᵈ r0 r1 shape-fn t0 t1 z out0 out1) + (let* ((f0 (ensure-flat t0)) + (f1 (ensure-flat t1)) + (fz (ensure-flat z)) + + (s0 (flat-shape f0)) + (sf0 (min-shape r0 s0)) + (stride0 (flat:size-of sf0)) + + (s1 (flat-shape t1)) + (sf1 (min-shape r1 s1)) + (stride1 (flat:size-of sf1)) + + (sf-z (shape-fn sf0 sf1)) + (stride-z (flat:size-of sf-z)) + + (v0 (flat-store f0)) + (v1 (flat-store f1)) + (vz (flat-store fz)) + + (off0 (flat-offset f0)) + (off1 (flat-offset f1)) + (offz (flat-offset fz))) + (ext2-shapes + s0 s1 r0 r1 sf-z + (λ (sz size-z q0 q1 strides) + (let ((g0 (new-vec (flat:size-of + s0) + 0.0)) + (g1 (new-vec (flat:size-of + s1) + 0.0))) + (for ([iz (in-range + 0 + size-z + stride-z)]) + (let-values (((i0 i1) + (idxs + strides + iz + off0 + off1))) + (fᵈ g0 g1 v0 i0 + stride0 + v1 i1 + stride1 + vz + (+ offz iz) + stride-z))) + (set-ext2-∇-result-res! out0 + (scalarize (flat s0 g0 0))) + (set-ext2-∇-result-res! out1 + (scalarize (flat s1 g1 0))))))))) + +(define rt:trefs + (λ (ft b) + (cond + ((= (flat:rank b) 1) (flat:trefs ft (vector->list (flat-store b)))) + (else (error 'trefs-err "~a should be a tensor¹" b))))) + +(define data-segment + (make-parameter #f)) + +(define data-segment-ref + (λ (i) + (vector-ref (data-segment) i))) + +(define-namespace-anchor a) +(define runtime + ;;TODO explicitly declare the names being included in this namespace + (namespace-anchor->namespace a)) + +(provide runtime flat? flat:build-tensor flat:list->tensor + flat:tref rt:trefs (struct-out ext2-∇-result) + ext2-∇-forcer scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ + flat flat-store flat-offset flat-ext1-ρ) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt new file mode 100644 index 0000000..3c7939a --- /dev/null +++ b/lazy/tensors/c2-interpreter.rkt @@ -0,0 +1,98 @@ +#lang racket + +(require "c0-ast.rkt") +(require (only-in "c1-racket-runtime.rkt" + runtime flat? flat:build-tensor flat:list->tensor + flat:tref rt:trefs ext2-∇-result-res ext2-∇-forcer + scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store + flat-offset flat-ext1-ρ)) + +(define make-instrs + (λ (instrs ds) + `(begin + (data-segment ,ds) + ,instrs))) + +(define interp-tensor-tcomp + (λ (tc env ds) + (match tc + [(tcomp-list->tensor lst) + (let ((eval-list + (for/list ((arg lst)) + (interp-tensor-expr arg env ds)))) + (flat:list->tensor eval-list))] + [(tcomp-build-tensor s f) + (flat:build-tensor s f)] + [(tcomp-tref tp i) + (flat:tref (interp-tensor-expr tp env ds) + (interp-tensor-expr i env ds))] + [(tcomp-trefs tp b) + (rt:trefs (interp-tensor-expr tp env ds) + (interp-tensor-expr b env ds))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let* ([b (if (zero? i) out0 out1)] + [v (ext2-∇-result-res b)]) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer fᵈ r0 r1 shape-fn + (interp-tensor-expr tp-t0 env ds) + (interp-tensor-expr tp-t1 env ds) + (interp-tensor-expr tp-z env ds) + out0 out1) + (ext2-∇-result-res b)) + (else v)))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (scalarize + (flat-ext1-∇ f m shape-fn + (ensure-flat (interp-tensor-expr tp env ds)) + (ensure-flat (interp-tensor-expr zp env ds))))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (f (interp-tensor-expr tp-t env ds) (interp-tensor-expr tp-u env ds))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (scalarize + (flat-ext2-ρ f m n shape-fn + (ensure-flat (interp-tensor-expr tp-t env ds)) + (ensure-flat (interp-tensor-expr tp-u env ds))))] + [(tcomp-ext1-ρ-scalar f tp) + (f (interp-tensor-expr tp env ds))] + [(tcomp-ext1-ρ f m shape-fn tp) + (scalarize + (flat-ext1-ρ f m shape-fn + (ensure-flat (interp-tensor-expr tp env ds))))] + [(tcomp-reshape s tp) + (flat s + (flat-store (interp-tensor-expr tp env ds)) + (flat-offset (interp-tensor-expr tp env ds)))] + [(tcomp-let lhs rhs body) + (interp-tensor-expr + body + (cons + (cons lhs + (interp-tensor-expr rhs env ds)) + env) + ds)] + [(tcomp-var name) + (cond + ((assv name env) + => + (λ (p) (cdr p))) + (else (error 'interpret-free "Free variable: ~a" name)))] + [(tcomp-ds-ref index) + (vector-ref ds index)]))) + +(define interp-tensor-expr + (λ (t env ds) + (match t + [(tpromise tc _) (interp-tensor-expr tc env ds)] + [v #:when (or (flat? v) (pair? v) (number? v)) v] + [(tcomp) (interp-tensor-tcomp t env ds)]))) + +(define interp-tensor + (λ (t ds) + (interp-tensor-expr t '() ds))) + +(define interp-racket + (lambda (instrs ds) + (eval (make-instrs instrs ds) runtime))) + +(provide interp-racket interp-tensor make-instrs) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt new file mode 100644 index 0000000..6899619 --- /dev/null +++ b/lazy/tensors/c3-compiler.rkt @@ -0,0 +1,660 @@ +#lang racket + +(require "c0-ast.rkt") +(require (only-in "c2-interpreter.rkt" make-instrs interp-tensor interp-racket)) +(require "../../flat-tensors/ext-impl.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require rackunit) +(require file/xxhash32) + +(struct counter-data (binding-name + ref-count) + #:transparent) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Compiler Passes +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +;;TODO: later eds and gs passes should not be needed because the tcomp AST nodes +;;should have a signature and dss field which will be populated at the time of +;;their instantiation. Then we just access those fields from the AST node rather +;;than computing them. Use a global data segment that has flat tensors used by all tcomp nodes in our program. + +;;Extracts the data segment which is a vector that contains scalars (arguments +;;to tref), flat tensors and flat tensor¹ of indices that will be the arguments +;;to trefs. +(define extract-data-segment + (λ (t) + (let-values (((t^ data-segment-stack) (eds-expr t '()))) + ;; convert data segment stack to data segment array + (values t^ (list->vector (reverse data-segment-stack)))))) + +(define eds-expr + (λ (t dss) + (match t + (s #:when (number? s) + (values s dss)) + (ft + #:when (flat? ft) + (cond + ((memq ft dss) + => + (λ (res) + (values (tcomp-ds-ref (length (cdr res))) dss))) + (else (values (tcomp-ds-ref (length dss)) (cons ft dss))))) + ((tpromise tc s) + (let-values (((tc^ dss^) (eds-expr tc dss))) + (cond + ((number? tc^) (values tc^ dss^)) + (else (values (tpromise tc^ s) dss^))))) + ((tcomp) (eds-tcomp t dss))))) + +(define eds-tcomp + (λ (tc dss) + (match tc + [(tcomp-list->tensor lst) + (let-values (((ts dss^) + (for/fold ((ts '()) + (dss^ dss)) + ((l lst)) + (let-values (((t dss^^) (eds-expr l dss^))) + (values (cons t ts) dss^^))))) + (values (tcomp-list->tensor (reverse ts)) dss^))] + [(tcomp-build-tensor s f) + (values tc dss)] + [(tcomp-tref tp i) + (let-values (((t dss^) (eds-expr tp dss))) + (values (tcomp-tref t (tcomp-ds-ref (length dss^))) + (cons i dss^)))] + [(tcomp-trefs tp b) + (let-values (((t dss^) (eds-expr tp dss))) + (values (tcomp-trefs t (tcomp-ds-ref (length dss^))) + (cons (flat:list->tensor b) dss^)))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let-values (((t0 dss^) (eds-expr tp-t0 dss))) + (let-values (((t1 dss^^) (eds-expr tp-t1 dss^))) + (let-values (((z dss^^^) (eds-expr tp-z dss^^))) + (values (tcomp-ext2-∇ fᵈ r0 r1 shape-fn t0 t1 z out0 out1 i) + dss^^^))))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (let-values (((tp^ dss^) (eds-expr tp dss))) + (let-values (((zp^ dss^^) (eds-expr zp dss^))) + (values (tcomp-ext1-∇ tp^ zp^ f m shape-fn) dss^^)))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (let-values (((t dss^) (eds-expr tp-t dss))) + (let-values (((u dss^^) (eds-expr tp-u dss^))) + (values (tcomp-ext2-ρ-scalar f t u) dss^^)))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (let-values (((t dss^) (eds-expr tp-t dss))) + (let-values (((u dss^^) (eds-expr tp-u dss^))) + (values (tcomp-ext2-ρ t u f m n shape-fn) dss^^)))] + [(tcomp-ext1-ρ-scalar f tp) + (let-values (((tp^ dss^) (eds-expr tp dss))) + (values (tcomp-ext1-ρ-scalar f tp^) dss^))] + [(tcomp-ext1-ρ f m shape-fn tp) + (let-values (((tp^ dss^) (eds-expr tp dss))) + (values (tcomp-ext1-ρ f m shape-fn tp^) dss^))] + [(tcomp-reshape s tp) + (let-values (((tp^ dss^) (eds-expr tp dss))) + (values (tcomp-reshape s tp^) dss^))]))) + +(define hash-signatures? + (make-parameter #t)) +;;TODO: Optimize sign by replacing it with this commented function +#; +(define sign + (let ((xxh32-ctx (make-xxh32))) + (λ ss + (cond + ((hash-signatures?) + (xxh32-reset! xxh32-ctx 0) + (xxh32-update! xxh32-ctx (apply bytes-append ss)) + (xxh32-digest xxh32-ctx)) + (else (format "~a" ss)))))) +(define sign + (let ((xxh32-ctx (make-xxh32))) + (λ (s) + (cond + ((hash-signatures?) + (xxh32-reset! xxh32-ctx 0) + (xxh32-update! xxh32-ctx (string->bytes/utf-8 s)) + (format "~a" (xxh32-digest xxh32-ctx))) + (else (format "~a" s)))))) + +#; +(define generate-signature + (λ (t) + (gs-expr t '()))) +(define generate-signature + (λ (t) + (gs-expr t))) + +#; +(define gs-expr + (λ (t position) + (match t + (s #:when (number? s) + (sign (format "s~a~a" s position))) + ((tpromise tc _) (gs-expr tc (cons 0 position))) + ((tcomp) (gs-tcomp t position))))) +(define gs-expr + (λ (t) + (match t + (s #:when (number? s) + (sign (format "s~a" s))) + ((tpromise tc _) (gs-expr tc)) + ((tcomp) (gs-tcomp t))))) + +#; +(define gs-tcomp + (λ (tc position) + (match tc + [(tcomp-list->tensor lst) + (let ((list-sig + (for/fold ((sig "")) + ((l lst) + (i (in-naturals 0))) + (string-append sig (gs-expr l (cons i position)))))) + (sign (format "l>t~a~a" list-sig position)))] + [(tcomp-build-tensor s f) + (sign (format "bt~a~a~a" s f position))] + [(tcomp-tref tp i) + (sign + (format "tr~a~a~a" + (gs-expr tp (cons 0 position)) + (gs-expr i (cons 1 position)) + position))] + [(tcomp-trefs tp b) + (sign + (format "trs~a~a~a" + (gs-expr tp (cons 0 position)) + (gs-expr b (cons 1 position)) + position))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (sign + (format "e2∇~a~a_~a~a~a~a~a~a~a" + fᵈ r0 r1 shape-fn + (gs-expr tp-t0 (cons 0 position)) + (gs-expr tp-t1 (cons 1 position)) + (gs-expr tp-z (cons 2 position)) + i + position))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (sign + (format "e1∇~a~a~a~a~a~a" + f m shape-fn + (gs-expr tp (cons 0 position)) + (gs-expr zp (cons 1 position)) + position))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (sign + (format "e2ρs~a~a~a~a" + f + (gs-expr tp-t (cons 0 position)) + (gs-expr tp-u (cons 1 position)) + position))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (sign + (format "e2ρ~a~a_~a~a~a~a~a" + f m n shape-fn + (gs-expr tp-t (cons 0 position)) + (gs-expr tp-u (cons 1 position)) + position))] + [(tcomp-ext1-ρ-scalar f tp) + (sign + (format "e1ρs~a~a~a" + f + (gs-expr tp (cons 0 position)) + position))] + [(tcomp-ext1-ρ f m shape-fn tp) + (sign + (format "e1ρ~a~a~a~a~a" + f m shape-fn + (gs-expr tp (cons 0 position)) + position))] + [(tcomp-reshape s tp) + (sign + (format "r~a~a~a" + s (gs-expr tp (cons 0 position)) position))] + [(tcomp-ds-ref index) + (sign + (format "dsr~a~a" index position))]))) +(define gs-tcomp + (λ (tc) + (match tc + [(tcomp-list->tensor lst) + (let ((list-sig + (for/fold ((sig "")) + ((l lst) + (i (in-naturals 0))) + (string-append sig (gs-expr l))))) + (sign (format "l>t~a" list-sig)))] + [(tcomp-build-tensor s f) + (sign (format "bt~a~a" s f))] + [(tcomp-tref tp i) + (sign + (format "tr~a~a" (gs-expr tp ) (gs-expr i )))] + [(tcomp-trefs tp b) + (sign + (format "trs~a~a" (gs-expr tp ) (gs-expr b )))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (sign + (format "e2∇~a~a_~a~a~a~a~a~a" + fᵈ r0 r1 shape-fn + (gs-expr tp-t0 ) + (gs-expr tp-t1 ) + (gs-expr tp-z ) + i))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (sign + (format "e1∇~a~a~a~a~a" + f m shape-fn + (gs-expr tp ) + (gs-expr zp )))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (sign + (format "e2ρs~a~a~a" + f + (gs-expr tp-t ) + (gs-expr tp-u )))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (sign + (format "e2ρ~a~a_~a~a~a~a" + f m n shape-fn + (gs-expr tp-t ) + (gs-expr tp-u )))] + [(tcomp-ext1-ρ-scalar f tp) + (sign + (format "e1ρs~a~a" f (gs-expr tp )))] + [(tcomp-ext1-ρ f m shape-fn tp) + (sign + (format "e1ρ~a~a~a~a" + f m shape-fn + (gs-expr tp )))] + [(tcomp-reshape s tp) + (sign + (format "r~a~a" + s (gs-expr tp ) ))] + [(tcomp-ds-ref index) + (sign + (format "dsr~a" index ))]))) + +;; Count references so that the tcomp AST nodes that refer to the same memory +;; location i.e. common AST nodes get extracted by let-binding them in the +;; compiled output racket code. +(define count-references + (λ (t) + (count-references-expr t (hasheq)))) + +(define count-references-expr + (λ (t counter) + (match t + ((tpromise tc _) + (count-references-expr tc counter)) + ((tcomp) (count-references-tcomp t counter)) + (_ counter)))) + +(define count-references-tcomp + (λ (tc counter) + (match-let (((counter-data tc-binding-name tc-ref-count) + (hash-ref counter tc + (λ () + (let-values (((st _) (struct-info tc))) + (let-values (((tcomp-name _0 _1 _2 _3 _4 _5 _6) + (struct-type-info st))) + (counter-data (gensym tcomp-name) 0))))))) + (let* ((new-count (add1 tc-ref-count)) + (counter^ (hash-set counter tc + (counter-data tc-binding-name + new-count)))) + (cond + ((> new-count 1) + ;; No need to increase reference count of children if parent occurs + ;; more than once. This helps avoid creating extra tcom-var later in + ;; ecs if child tcomp occurs only once within the parent tcomp, but + ;; the parent tcomp itself occurs more than once. + counter^) + (else + (match tc + [(tcomp-list->tensor lst) + (for/fold + ((counter^^ counter^)) + ((l lst)) + (count-references-expr l counter^^))] + [(tcomp-build-tensor s f) counter^] + [(tcomp-tref tp i) + (count-references-expr tp counter^)] + [(tcomp-trefs tp b) + (count-references-expr tp counter^)] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (count-references-expr + tp-z + (count-references-expr + tp-t1 + (count-references-expr tp-t0 counter^)))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (count-references-expr + zp + (count-references-expr tp counter^))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (count-references-expr + tp-u + (count-references-expr tp-t counter^))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (count-references-expr + tp-u + (count-references-expr tp-t counter^))] + [(tcomp-ext1-ρ-scalar f tp) + (count-references-expr tp counter^)] + [(tcomp-ext1-ρ f m shape-fn tp) + (count-references-expr tp counter^)] + [(tcomp-reshape s tp) + (count-references-expr tp counter^)] + [(tcomp-ds-ref index) counter^] + ;;need these cases for testing compiler invariant + [(tcomp-let lhs rhs body) + (count-references-expr + body + (count-references-expr rhs counter^))] + [(tcomp-var name) counter^]))))))) + +(define extract-common-subexpressions + (λ (t counter) + (let-values (((instrs bindings) + (run-compiler-ecs (ecs-expr t counter) '()))) + (for/fold ((body instrs)) + ((binding bindings)) + (tcomp-let (car binding) (cdr binding) body))))) + +(define ecs-expr + (λ (tc counter) + (match tc + [(tpromise tc s) + (->ecs + (ecs-expr tc counter) + (λ (instrs) + (inj-ecs-val (tpromise instrs s))))] + [tc #:when (number? tc) + (inj-ecs-val tc)] + [(tcomp) (ecs-tcomp tc counter)]))) + +(define ecs-tcomp + (λ (tc counter) + (let ((tc-counter-data + (hash-ref counter tc + (λ () + (counter-data (gensym 'illegal) 0))))) + (match tc + [(tcomp-list->tensor lst) + (let ((instrs-list-compiler + (for/foldr + ((list-compiler (inj-ecs-val '()))) + ((arg lst)) + (->ecs + (ecs-expr arg counter) + (λ (instrs) + (->ecs + list-compiler + (λ (instrs-list) + (inj-ecs-val (cons instrs instrs-list))))))))) + (->ecs + instrs-list-compiler + (λ (instrs-list) + (inj-ecs-tcomp (tcomp-list->tensor instrs-list) tc-counter-data))))] + [(tcomp-build-tensor s f) + (inj-ecs-tcomp tc tc-counter-data)] + [(tcomp-tref tp i) + (->ecs + (ecs-expr tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-tref instrs i) tc-counter-data)))] + [(tcomp-trefs tp b) + (->ecs + (ecs-expr tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-trefs instrs b) tc-counter-data)))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (->ecs + (ecs-expr tp-t0 counter) + (λ (t0-instrs) + (->ecs + (ecs-expr tp-t1 counter) + (λ (t1-instrs) + (->ecs + (ecs-expr tp-z counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-ext2-∇ fᵈ r0 r1 shape-fn + t0-instrs t1-instrs z-instrs + out0 out1 i) + tc-counter-data)))))))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (->ecs + (ecs-expr tp counter) + (λ (t-instrs) + (->ecs + (ecs-expr zp counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-ext1-∇ t-instrs z-instrs f m shape-fn) + tc-counter-data)))))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (->ecs + (ecs-expr tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-expr tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-ext2-ρ-scalar f t-instrs u-instrs) + tc-counter-data)))))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (->ecs + (ecs-expr tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-expr tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-ext2-ρ t-instrs u-instrs f m n shape-fn) + tc-counter-data)))))] + [(tcomp-ext1-ρ-scalar f tp) + (->ecs + (ecs-expr tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f instrs) tc-counter-data)))] + [(tcomp-ext1-ρ f m shape-fn tp) + (->ecs + (ecs-expr tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-ext1-ρ f m shape-fn instrs) tc-counter-data)))] + [(tcomp-reshape s tp) + (->ecs + (ecs-expr tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-reshape s instrs) tc-counter-data)))] + [(tcomp-ds-ref index) (inj-ecs-tcomp tc tc-counter-data)])))) + +(struct CompilerECS (run-compiler) #:transparent) + +(define run-compiler-ecs + (λ (c bindings) + ((CompilerECS-run-compiler c) bindings))) + +(define inj-ecs-val + (λ (v) + (CompilerECS (λ (bindings) (values v bindings))))) + +(define inj-ecs-tcomp + (λ (instrs cd) + (match-let (((counter-data binding-var ref-count) cd)) + (CompilerECS (λ (bindings) + (cond + ((<= ref-count 1) + (values instrs bindings)) + ((assv binding-var bindings) + (values (tcomp-var binding-var) bindings)) + (else + (values (tcomp-var binding-var) + (extend-bindings binding-var instrs bindings))))))))) + +(define ->ecs + (λ (c f) + (CompilerECS + (λ (bindings) + (let-values (((instrs bindings^) (run-compiler-ecs c bindings))) + (run-compiler-ecs (f instrs) bindings^)))))) + +(define extend-env + (λ (k v env) + `((,k . ,v) . ,env))) + +(define extend-bindings extend-env) + +(define exists-in-env? + (λ (ft env) + (match env + ('() #f) + (`((,k . ,v) . ,_) #:when (eq? ft v) k) + (`(,_ . ,rest-env) (exists-in-env? ft rest-env))))) + +(define generate-racket + (λ (t) + (gr-expr t))) + +(define gr-expr + (λ (t) + (match t + [(tpromise tc _) (gr-expr tc)] + [v #:when (number? v) v] + [(tcomp) (gr-tcomp t)]))) + +(define gr-tcomp + (λ (tc) + (match tc + [(tcomp-list->tensor lst) + (let ((instrs-list (map gr-expr lst))) + `(flat:list->tensor (list ,@instrs-list)))] + [(tcomp-build-tensor s f) + (flat:build-tensor s f)] + [(tcomp-tref tp i) + (let ((instrs (gr-expr tp)) + (i-instrs (gr-expr i))) + `(flat:tref ,instrs ,i-instrs))] + [(tcomp-trefs tp b) + (let ((instrs (gr-expr tp)) + (b-instrs (gr-expr b))) + `(rt:trefs ,instrs ,b-instrs))] + [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (gr-expr tp-t0)) + (t1-instrs (gr-expr tp-t1)) + (z-instrs (gr-expr tp-z))) + `(let* ([b (if (zero? ,i) ,out0 ,out1)] + [v (ext2-∇-result-res b)]) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out0 ,out1) + (ext2-∇-result-res b)) + (else v))))] + [(tcomp-ext1-∇ tp zp f m shape-fn) + (let ((t-instrs (gr-expr tp)) + (z-instrs (gr-expr zp))) + `(scalarize + (flat-ext1-∇ ,f ,m ,shape-fn + (ensure-flat ,t-instrs) + (ensure-flat ,z-instrs))))] + [(tcomp-ext2-ρ-scalar f tp-t tp-u) + (let ((t-instrs (gr-expr tp-t)) + (u-instrs (gr-expr tp-u))) + `(,f ,t-instrs ,u-instrs))] + [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + (let ((t-instrs (gr-expr tp-t)) + (u-instrs (gr-expr tp-u))) + `(scalarize + (flat-ext2-ρ ,f ,m ,n ,shape-fn + (ensure-flat ,t-instrs) + (ensure-flat ,u-instrs))))] + [(tcomp-ext1-ρ-scalar f tp) + (let ((instrs (gr-expr tp))) + `(,f ,instrs))] + [(tcomp-ext1-ρ f m shape-fn tp) + (let ((instrs (gr-expr tp))) + `(scalarize + (flat-ext1-ρ ,f ,m ,shape-fn + (ensure-flat ,instrs))))] + [(tcomp-reshape s tp) + (let ((instrs (gr-expr tp))) + `(flat ',s + (flat-store ,instrs) + (flat-offset ,instrs)))] + [(tcomp-let lhs rhs body) + (let ((rhs-instrs (gr-expr rhs)) + (body-instrs (gr-expr body))) + `(let ((,lhs ,rhs-instrs)) + ,body-instrs))] + [(tcomp-var name) name] + [(tcomp-ds-ref index) `(data-segment-ref ,index)]))) + +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; +;; Composing Compiler Passes +;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; + +(define print-compiler? (make-parameter #f)) +(define display-compiler-trace + (λ (title value) + (when (or (equal? (print-compiler?) #t) + (and (list? (print-compiler?)) (member title (print-compiler?)))) + (printf "~a:~n" title) + (displayln "--------------") + (pretty-print value) + (displayln "")))) +(define compile-tensor + (let ((cache (make-hash))) + (λ (t) + (display-compiler-trace 'Source-Tensor t) + (let-values (((eds-instrs ds) (extract-data-segment t))) + (display-compiler-trace 'Extract-Data-Segment-data ds) + (display-compiler-trace 'Extract-Data-Segment-instructions eds-instrs) + (let ((signature (generate-signature eds-instrs))) + (display-compiler-trace 'Generate-Signature signature) + (cond + ;; TODO: Uncomment this to reenable caching + (#f #;(hash-has-key? cache signature) + (let ((compiled (hash-ref cache signature))) + (display-compiler-trace 'Cache-Hit compiled) + (values compiled ds))) + (else + (let ((counter (count-references eds-instrs))) + (display-compiler-trace 'Count-References counter) + (let ((extracted (extract-common-subexpressions eds-instrs counter))) + (display-compiler-trace 'Extract-Common-Subexpressions extracted) + (let ((rkt (generate-racket extracted))) + (display-compiler-trace 'Generate-Racket rkt) + (hash-set! cache signature rkt) + (values rkt ds))))))))))) + +;;TODO: update this for new compiler passes +(define compile-tensor/checks + (λ (t) + (let-values (((eds-instrs ds) (extract-data-segment t))) + (flat:check-tensor-equal? (interp-tensor t) (interp-tensor eds-instrs)) + (let ((counter (count-references t))) + (let ((extracted (extract-common-subexpressions t counter))) + (flat:check-tensor-equal? (interp-tensor t) (interp-tensor extracted)) + (for/list ((cd (hash-values (count-references extracted)))) + (check-equal? (counter-data-ref-count cd) 1)) + (let-values (((rkt env) (generate-racket extracted))) + (flat:check-tensor-equal? (interp-tensor extracted) + (interp-racket rkt env)) + (values rkt env))))))) + +(define get-compiled + (λ (t) + (let-values (((instrs env) + (compile-tensor t))) + (make-instrs instrs env)))) + +(include "test/test-c3-compiler.rkt") +(provide get-compiled compile-tensor compile-tensor/checks print-compiler?) diff --git a/lazy/tensors/test/test-0-lazy.rkt b/lazy/tensors/test/test-0-lazy.rkt index cbbbde1..5592dbe 100644 --- a/lazy/tensors/test/test-0-lazy.rkt +++ b/lazy/tensors/test/test-0-lazy.rkt @@ -1,420 +1,7 @@ (module+ test (require rackunit) - ;; TODO: Add a comment above each test case describing what the test case is testing - (define-check (check-compiler-invariants tp) - (let-values (((instrs locals env) (run-compiler - (compile-expr tp - (count-references tp (hasheq))) - '() '()))) - (with-check-info - (('env (nested-info - (map (λ (name/flat) - (make-check-info (car name/flat) (cdr name/flat))) - env))) - ('instrs instrs)) - (for ((name/flat env)) - (unless (and (flat:flat? (cdr name/flat)) - (not (null? (flat-shape (cdr name/flat))))) - (fail-check (format (string-append "Value associated with the variable" - " ~a should be a flat tensor. " - "Associated value found: ~a") - (car name/flat) (cdr name/flat))))) - (define unique-flats (list->seteq (map cdr env))) - (unless (equal? (set-count unique-flats) - (length (filter flat? (map cdr env)))) - (fail-check (string-append "Duplicate flat tensors found" - " in environment. Variables in environment" - " should be paired with unique" - " flat tensors")))))) - - (define test-lt (tensor 1 2 3)) - (check-compiler-invariants test-lt) - (check-true (flat? (tpromise-tensor test-lt))) - (check-equal? (flat-store (force/eval test-lt)) (vector 1 2 3)) - (check-true (flat? (tpromise-tensor test-lt))) - (check-exn exn:fail? (λ () (tensor test-lt 4))) - (check-exn exn:fail? (λ () (tensor 4 test-lt))) - - (define test-tcomp-tref (tp-tref test-lt 2)) - (check-compiler-invariants test-tcomp-tref) - (check-equal? (force/eval test-tcomp-tref) 3) - (check-exn exn:fail? (λ () (tp-tref test-lt 5))) - - (define test-nested-lt (tensor (tensor 1 2 3) (tensor 4 5 6))) - (define test-tcomp-tref-nested (tp-tref (tp-tref test-nested-lt 0) 2)) - (check-compiler-invariants test-tcomp-tref-nested) - (check-equal? (force/eval test-tcomp-tref-nested) 3) - (check-exn exn:fail? (λ () (tp-tref (tp-tref test-nested-lt 2) 0))) - (check-exn exn:fail? (λ () (tp-tref test-nested-lt 2))) - (check-exn exn:fail? (λ () (tensor test-nested-lt test-nested-lt test-lt))) - - (check-equal? (tp-tlen test-lt) 3) - (check-equal? (tp-tlen test-nested-lt) 2) - - (define test-lt-from-list (list->tpromise '(5 6 7 8))) - (check-compiler-invariants test-lt-from-list) - (check-equal? (flat-store (force/eval test-lt-from-list)) (vector 5 6 7 8)) - (define test-nested-lt-from-list - (list->tpromise `(,test-lt ,test-lt ,test-lt))) - (check-compiler-invariants test-nested-lt-from-list) - (check-equal? (flat-store (force/eval test-nested-lt-from-list)) - (vector 1 2 3 1 2 3 1 2 3)) - (check-equal? (tpromise-shape test-nested-lt-from-list) '(3 3)) - - (check-true (bounded-idx*? test-nested-lt-from-list (list 0 1))) - (check-false (bounded-idx*? test-nested-lt-from-list (list 1 3))) - (check-false (bounded-idx*? test-nested-lt-from-list (list 1 1 0))) - - (define test-tcomp-partial-eval - (begin - (force/eval test-nested-lt-from-list) - (force/eval test-nested-lt) - (force/eval test-lt) - (tp-tref - (tp-tref (tensor (tensor (tensor 1 2 3) (tensor 4 5 6) (tensor 7 8 9)) - test-nested-lt-from-list - (list->tpromise (list (tp-tref test-nested-lt 0) - (tp-tref test-nested-lt 1) - test-lt))) - 1) - 2))) - (check-compiler-invariants test-tcomp-partial-eval) - (flat:check-tensor-equal? (force/eval test-tcomp-partial-eval) - (force/eval (tensor 1 2 3))) - - (define test-build-shape '(4 3)) - (define test-built-tensor (build-tpromise test-build-shape - (λ (i) - (let ([row (car i)] - [column (cadr i)]) - (+ (* (sub1 (car test-build-shape)) - row) - column))))) - (check-compiler-invariants test-built-tensor) - (check-equal? (tpromise-shape test-built-tensor) test-build-shape) - (check-true (tcomp? (tpromise-tensor test-built-tensor))) - (flat:check-tensor-equal? (force/eval test-built-tensor) - (force/eval (tensor (tensor 0 1 2) - (tensor 3 4 5) - (tensor 6 7 8) - (tensor 9 10 11)))) - - (define test-refs '(0 2)) - (define test-tp-trefs (tp-trefs test-built-tensor test-refs)) - (check-compiler-invariants test-tp-trefs) - (check-true (tcomp? (tpromise-tensor test-tp-trefs))) - (check-equal? (tpromise-shape test-tp-trefs) - (flat-shape (force/eval test-tp-trefs))) - (flat:check-tensor-equal? (force/eval test-tp-trefs) - (force/eval (tensor (tensor 0 1 2) - (tensor 6 7 8)))) - (check-exn exn:fail? (λ () (tp-trefs test-nested-lt '(0 4)))) - - (define test-tp-reshape (tp-reshape '(3 2 1) (tp-trefs test-built-tensor '(1 3)))) - (check-compiler-invariants test-tp-reshape) - (flat:check-tensor-equal? (force/eval test-tp-reshape) - (force/eval (tensor (tensor (tensor 3) - (tensor 4)) - (tensor (tensor 5) - (tensor 9)) - (tensor (tensor 10) - (tensor 11))))) - (check-exn exn:fail? (λ () (tp-reshape '(4 5) test-tp-reshape))) - - (define sum-f - (λ (in-v iᵢ sᵢ out-v iₒ sₒ) - (vset! out-v iₒ - (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) - (+ sum (vref in-v i)))))) - - (define sum (tp-ext1-ρ sum-f 1)) - (define test-tp-sum (sum test-nested-lt)) - (check-compiler-invariants test-tp-sum) - (flat:check-tensor-equal? (force/eval test-tp-sum) - (force/eval (tensor 6.0 15.0))) - - (define test-tp-sum-nested (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) - (check-compiler-invariants test-tp-sum-nested) - (flat:check-tensor-equal? (force/eval test-tp-sum-nested) - (force/eval (tensor 4.0 6.0 5.0))) - - (define id-f (lambda (v) v)) - (define id-ρ (tp-ext1-ρ id-f 1 (λ (s) s))) - (define test-tp-id (id-ρ test-nested-lt)) - (check-compiler-invariants test-tp-id) - (flat:check-tensor-equal? (force/eval test-tp-id) - (force/eval (tensor (tensor 1 2 3) - (tensor 4 5 6)))) - - (define test-tp-id-scalar (id-ρ (sum (tensor 4 5 6)))) - (check-compiler-invariants test-tp-id-scalar) - (check-equal? (force/eval test-tp-id-scalar) 15.0) - - (define t0 - (build-tpromise '(2 3 4) - (λ (i) - (match-define `(,x ,y ,z) i) - (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) - (define *-ρ (tp-ext2-ρ * 0 0)) - (define t0sqr (*-ρ t0 t0)) - - (check-compiler-invariants t0sqr) - (flat:check-tensor-equal? (force/eval t0sqr) - (flat:reshape - '(2 3 4) - (flat:tensor - 0 4 16 36 - 64 100 144 196 - 256 324 400 484 - 576 676 784 900 - 1024 1156 1296 1444 - 1600 1764 1936 2116))) - - (define *-2-1-f - (λ (v0 i0 s0 v1 i1 s1 vout iout sout) - (for ([j0 (in-range 0 s0)]) - (vset! vout (+ iout j0) - (* (vref v0 (+ i0 j0)) - (vref v1 (+ i1 (modulo j0 s1)))))))) - - (define t1 - (build-tpromise '(5 6) - (λ (i) - (match-define `(,x ,y) i) - (* 2.0 (+ (* x 6) y))))) - - (define t2 - (build-tpromise '(6) - (λ (i) (* 3.0 (car i))))) - - (define *-2-1 - (tp-ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) - - (define r-1-2 - (*-2-1 t1 t2)) - - (check-equal? (tpromise-shape r-1-2) '(5 6)) - (check-compiler-invariants r-1-2) - (flat:check-tensor-equal? (force/eval r-1-2) - (flat:reshape - '(5 6) - (flat:tensor - 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0))) - - (define t3 - (build-tpromise '(3 5 6) - (λ (i) - (match-define `(,x ,y ,z) i) - (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) - - (define t4 - (build-tpromise '(3 6) - (λ (i) - (match-define `(,x ,y) i) - (* 3.0 (+ (* x 6) y))))) - - (define r-3-4 - (*-2-1 t3 t4)) - - (check-equal? (tpromise-shape r-3-4) '(3 5 6)) - (check-compiler-invariants r-3-4) - (flat:check-tensor-equal? (force/eval r-3-4) - (flat:reshape - '(3 5 6) - (flat:tensor - 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0 - - 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 - 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 - 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 - 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 - 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 - - 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 - 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 - 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 - 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 - 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) - - (define r-sum-2-scalar (*-ρ (sum t2) (sum (tensor 2 3 4)))) - (check-compiler-invariants r-sum-2-scalar) - (flat:check-tensor-equal? (force/eval r-sum-2-scalar) 405.0) - - (define r1-td (tensor 3.0 4.0 5.0)) - (define r2-td (tp-reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) - - (define +ᶠ +) - (define +ᵈ (λ (a b z) (values z z))) - - (define sqrᶠ (λ (a) (* a a))) - (define sqrᵈ - (λ (a z) (* z 2 a))) - - (define d-sqr (tp-ext1-∇ sqrᵈ 0 scalar-shape)) - - (define one-like - (λ (t) - (build-tpromise (tpromise-shape t) (λ (_) 1.0)))) - - (define tcomp-dsqr-r1 (d-sqr r1-td (one-like r1-td))) - (check-compiler-invariants tcomp-dsqr-r1) - (flat:check-tensor-equal? (force/eval tcomp-dsqr-r1) - (flat:tensor 6.0 8.0 10.0)) - - (let ((gsqr (d-sqr r2-td (one-like r2-td)))) - (check-compiler-invariants gsqr) - (flat:check-tensor-equal? (force/eval gsqr) - (flat:reshape - '(2 3) - (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) - - (define d+ (tp-ext2-∇ +ᵈ 0 0 scalar-shape)) - - (let-values (((da db) (d+ 2.0 3.0 1.0))) - (check-compiler-invariants da) - (flat:check-tensor-equal? (force/eval da) 1.0) - (check-compiler-invariants db) - (flat:check-tensor-equal? (force/eval db) 1.0)) - - (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) - (check-compiler-invariants da) - (flat:check-tensor-equal? (force/eval da) - (flat:tensor 1.0 1.0 1.0)) - (check-compiler-invariants db) - (flat:check-tensor-equal? (force/eval db) - (flat:tensor 1.0 1.0 1.0))) - - (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) - (check-compiler-invariants da) - (flat:check-tensor-equal? (force/eval da) - (flat:tensor 2.0 2.0 2.0)) - (check-compiler-invariants db) - (flat:check-tensor-equal? (force/eval db) - (flat:reshape - '(2 3) - (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) - - (define *∇ (tp-ext2-∇ (λ (a b z) (values (* z b) (* z a))) - 0 - 0)) - - (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) - (tensor 1.0 2.0 3.0) - (tensor 1.0 1.0 1.0)))) - (check-compiler-invariants gt) - (check-compiler-invariants gu) - (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 2.0 3.0))) - (flat:check-tensor-equal? (force/eval gu) (force/eval (tensor 2.0 3.0 4.0)))) - - (define sum-1-∇ - (λ (g t it st vz iz sz) - (for* ([i (in-range it (+ it st))]) - (vset! g i (vref vz iz))))) - - (define sum-∇ (tp-ext1-∇ sum-1-∇ 1 (λ (s) '()))) - - (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) - 1.0))) - (check-compiler-invariants gt) - (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 1.0 1.0)))) - - (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) - (tensor 2.0 3.0 4.0)) - (tensor 2.0 1.0)))) - (check-compiler-invariants gt) - (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor (tensor 2.0 2.0 2.0) - (tensor 1.0 1.0 1.0))))) - ;; t and u must have the same shape - (define s2-f (lambda (t u) (tensor (sum t) (sum u)))) - (define s2-d - (λ (g0 g1 t it st u iu su vz iz sz) - (for* ([i (in-range it (+ it st))]) - (vset! g0 i (vref vz iz)) - (vset! g1 i (vref vz (+ iz 1)))))) - (define s2-∇ (tp-ext2-∇ s2-d 1 1 (λ (s0 s1) (list 2)))) - (let-values (((gt gu) (s2-∇ (tensor 2.0 3.0 4.0) - (tensor 1.0 2.0 3.0) - (tensor 1.0 1.0)))) - (check-compiler-invariants gt) - (check-compiler-invariants gu) - (flat:check-tensor-equal? (force/eval gt) (force/eval (tensor 1.0 1.0 1.0))) - (flat:check-tensor-equal? (force/eval gu) (force/eval (tensor 1.0 1.0 1.0)))) - (let-values (((gt gu) (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) - (tensor 3.0 4.0 6.0)) - (tensor (tensor 5.0 6.0 6.0) - (tensor 7.0 8.0 6.0)) - (tensor (tensor 8.0 7.0 6.0) - (tensor 6.0 5.0 6.0))) - (tensor (tensor (tensor 6.0 8.0 6.0) - (tensor 3.0 4.0 6.0)) - (tensor (tensor 9.0 7.0 6.0) - (tensor 8.0 2.0 6.0)) - (tensor (tensor 9.0 7.0 6.0) - (tensor 5.0 1.0 6.0))) - (tensor (tensor (tensor 1.0 1.0) - (tensor 1.0 1.0)) - (tensor (tensor 1.0 1.0) - (tensor 1.0 1.0)) - (tensor (tensor 1.0 1.0) - (tensor 1.0 1.0)))))) - (check-compiler-invariants gt) - (check-compiler-invariants gu) - (flat:check-tensor-equal? (force/eval gt) (force/eval (tp-reshape '(3 2 3) (list->tpromise (make-list 18 1.0))))) - (flat:check-tensor-equal? (force/eval gu) (force/eval (tp-reshape '(3 2 3) (list->tpromise (make-list 18 1.0)))))) - - (define test-env-flat-scalar ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) (list (tensor 1.0) 3.0))) - (check-compiler-invariants test-env-flat-scalar) - - ;; Check common subexpression introduced by let is not repeated - ;; TODO: add a generic version of the next 2 tests in check-compiler-invariants - (define count-flat:tref - (λ (ls) - (cond - ((null? ls) 0) - ((pair? (car ls)) (+ (count-flat:tref (car ls)) - (count-flat:tref (cdr ls)))) - ((eqv? (car ls) 'flat:tref) - (add1 (count-flat:tref (cdr ls)))) - (else (count-flat:tref (cdr ls)))))) - (define test-common-subexpr - (let ((t (tp-tref (tensor 1 2 3) 0))) - (tensor t t))) - (let-values (((instrs locals env) (run-compiler - (compile-expr test-common-subexpr - (count-references - test-common-subexpr - (hasheq))) - '() '()))) - (check-equal? (count-flat:tref (make-instrs instrs locals env)) 1 - "Common subexpression containing flat:tref should occur once")) - (define test-common-nested-subexprs - (let ((t1 (tp-tref (tensor (tensor 1 2 3) (tensor 4 5 6)) 0))) - (let ((t0 (tp-tref t1 0))) - (tensor t0 t0)))) - (let-values (((instrs locals env) (run-compiler - (compile-expr test-common-nested-subexprs - (count-references - test-common-nested-subexprs - (hasheq))) - '() '()))) - (check-equal? (count-flat:tref (make-instrs instrs locals env)) 2 - "Common subexpressions containing flat:tref should occur twice")) - - (define random-tensor - (λ (s) - (build-tpromise s (λ (tidx) (random 10))))) - (define test-build-random - (let ((v (random-tensor '(3 2 4)))) - (*-ρ v v))) - (check-pred - (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) - (vector->list (flat:flat-store (force/eval test-build-random))))) + (define test-nested-tensor (tensor (tensor 1 2 3) (tensor 4 5 6) (tensor 7 8 9))) + (check-true (bounded-idx*? test-nested-tensor (list 0 1))) + (check-false (bounded-idx*? test-nested-tensor (list 1 3))) + (check-false (bounded-idx*? test-nested-tensor (list 1 1 0)))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt new file mode 100644 index 0000000..6346084 --- /dev/null +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -0,0 +1,280 @@ +(module+ test + (require rackunit) + (require (only-in "c3-compiler.rkt" + compile-tensor/checks)) + (require "0-lazy.rkt") + (require "B-test-programs.rkt") + ;; TODO: Add a comment above each test case describing what the test case is testing + (define-check (check-compiler-invariants tp) + (let-values (((instrs ds) (compile-tensor tp))) + (with-check-info + (('data-segment ds) + ('instrs instrs)) + 'ok + #; + (for ((name/flat ds)) + (unless (and (flat:flat? (cdr name/flat)) + (not (null? (flat:flat-shape (cdr name/flat))))) + (fail-check (format (string-append "Value associated with the variable" + " ~a should be a flat tensor. " + "Associated value found: ~a") + (car name/flat) (cdr name/flat))))) + #; + (define unique-flats (list->seteq (map cdr ds))) + #; + (unless (equal? (set-count unique-flats) + (length (filter flat? (map cdr ds)))) + (fail-check (string-append "Duplicate flat tensors found" + " in data segment. Variables in data segment" + " should be paired with unique" + " flat tensors")))))) + ;;TODO: Move all check-compiler-invariant checks to the test file for + ;;c3-compiler.rkt file. + + ;;TODO: Refactor all test cases to use get-test-program so that ↓ doesn't + ;;mutate the programs defined in B-test-programs + (define test-tensor-r1-0 (get-test-program 'tensor-r1-0)) + (check-compiler-invariants test-tensor-r1-0) + (check-true (flat:flat? (tpromise-tensor test-tensor-r1-0))) + (flat:check-tensor-equal? (↓ test-tensor-r1-0) + (get-test-eval-res 'tensor-r1-0)) + (check-true (flat:flat? (tpromise-tensor test-tensor-r1-0))) + (check-exn exn:fail? (λ () (tensor test-tensor-r1-0 4))) + (check-exn exn:fail? (λ () (tensor 4 test-tensor-r1-0))) + + (check-compiler-invariants test-tcomp-tref) + (check-equal? (↓ test-tcomp-tref) 3) + (check-exn exn:fail? (λ () (tref test-tensor-r1-0 5))) + + (define test-nested-tensor (get-test-program 'tensor-r2-0)) + (check-compiler-invariants test-tcomp-tref-nested) + (check-equal? (↓ test-tcomp-tref-nested) 3) + (check-exn exn:fail? (λ () (tref (tref test-nested-tensor 2) 0))) + (check-exn exn:fail? (λ () (tref test-nested-tensor 2))) + (check-exn exn:fail? (λ () (tensor test-nested-tensor test-nested-tensor test-tensor-r1-0))) + + (check-equal? (tlen test-tensor-r1-0) 3) + (check-equal? (tlen test-nested-tensor) 2) + + (check-compiler-invariants test-list->tensor) + (check-equal? (flat:flat-store (↓ test-list->tensor)) (vector 5 6 7 8)) + (check-compiler-invariants test-nested-list->tensor) + (check-equal? (flat:flat-store (↓ test-nested-list->tensor)) + (vector 1 2 3 1 2 3 1 2 3)) + (check-equal? (tpromise-shape test-nested-list->tensor) '(3 3)) + + (define test-tcomp-partial-eval + (begin + (↓ test-nested-list->tensor) + (↓ test-nested-tensor) + (↓ test-tensor-r1-0) + (tref + (tref (tensor (tensor (tensor 1 2 3) (tensor 4 5 6) (tensor 7 8 9)) + test-nested-list->tensor + (list->tensor (list (tref test-nested-tensor 0) + (tref test-nested-tensor 1) + test-tensor-r1-0))) + 1) + 2))) + (check-compiler-invariants test-tcomp-partial-eval) + (flat:check-tensor-equal? (↓ test-tcomp-partial-eval) + (↓ (tensor 1 2 3))) + + (check-compiler-invariants test-built-tensor) + (check-equal? (tpromise-shape test-built-tensor) test-build-shape) + (check-true (tcomp? (tpromise-tensor test-built-tensor))) + (flat:check-tensor-equal? (↓ test-built-tensor) + (↓ (tensor (tensor 0 1 2) + (tensor 3 4 5) + (tensor 6 7 8) + (tensor 9 10 11)))) + + (check-compiler-invariants test-trefs) + (check-true (tcomp? (tpromise-tensor test-trefs))) + (check-equal? (tpromise-shape test-trefs) + (flat:flat-shape (↓ test-trefs))) + (flat:check-tensor-equal? (↓ test-trefs) + (↓ (tensor (tensor 0 1 2) + (tensor 6 7 8)))) + (check-exn exn:fail? (λ () (trefs test-nested-tensor '(0 4)))) + + (check-compiler-invariants test-reshape) + (flat:check-tensor-equal? (↓ test-reshape) + (↓ (tensor (tensor (tensor 3) + (tensor 4)) + (tensor (tensor 5) + (tensor 9)) + (tensor (tensor 10) + (tensor 11))))) + (check-exn exn:fail? (λ () (reshape '(4 5) test-reshape))) + + (check-compiler-invariants test-tp-sum) + (flat:check-tensor-equal? (↓ test-tp-sum) + (↓ (tensor 6.0 15.0))) + + (check-compiler-invariants test-tp-sum-nested) + (flat:check-tensor-equal? (↓ test-tp-sum-nested) + (↓ (tensor 4.0 6.0 5.0))) + + (check-compiler-invariants test-tp-id) + (flat:check-tensor-equal? (↓ test-tp-id) + (↓ (tensor (tensor 1 2 3) + (tensor 4 5 6)))) + + (check-compiler-invariants test-tp-id-scalar) + (check-equal? (↓ test-tp-id-scalar) 15.0) + + (check-compiler-invariants t0sqr) + (flat:check-tensor-equal? (↓ t0sqr) + (flat:reshape + '(2 3 4) + (flat:tensor + 0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116))) + + (check-equal? (tpromise-shape r-1-2) '(5 6)) + (check-compiler-invariants r-1-2) + (flat:check-tensor-equal? (↓ r-1-2) + (flat:reshape + '(5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0))) + + (check-equal? (tpromise-shape r-3-4) '(3 5 6)) + (check-compiler-invariants r-3-4) + (flat:check-tensor-equal? (↓ r-3-4) + (flat:reshape + '(3 5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) + + (check-compiler-invariants r-sum-2-scalar) + (flat:check-tensor-equal? (↓ r-sum-2-scalar) 405.0) + + (check-compiler-invariants tcomp-dsqr-r1) + (flat:check-tensor-equal? (↓ tcomp-dsqr-r1) + (flat:tensor 6.0 8.0 10.0)) + + (let ((gsqr (d-sqr r2-td (one-like r2-td)))) + (check-compiler-invariants gsqr) + (flat:check-tensor-equal? (↓ gsqr) + (flat:reshape + '(2 3) + (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + + (let-values (((da db) (d+ 2.0 3.0 1.0))) + (check-compiler-invariants da) + (flat:check-tensor-equal? (↓ da) 1.0) + (check-compiler-invariants db) + (flat:check-tensor-equal? (↓ db) 1.0)) + + (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) + (check-compiler-invariants da) + (flat:check-tensor-equal? (↓ da) + (flat:tensor 1.0 1.0 1.0)) + (check-compiler-invariants db) + (flat:check-tensor-equal? (↓ db) + (flat:tensor 1.0 1.0 1.0))) + + (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) + (check-compiler-invariants da) + (flat:check-tensor-equal? (↓ da) + (flat:tensor 2.0 2.0 2.0)) + (check-compiler-invariants db) + (flat:check-tensor-equal? (↓ db) + (flat:reshape + '(2 3) + (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + + (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0 1.0)))) + (check-compiler-invariants gt) + (check-compiler-invariants gu) + (flat:check-tensor-equal? (↓ gt) (↓ (tensor 1.0 2.0 3.0))) + (flat:check-tensor-equal? (↓ gu) (↓ (tensor 2.0 3.0 4.0)))) + + (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) + 1.0))) + (check-compiler-invariants gt) + (flat:check-tensor-equal? (↓ gt) (↓ (tensor 1.0 1.0 1.0)))) + + (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) + (tensor 2.0 3.0 4.0)) + (tensor 2.0 1.0)))) + (check-compiler-invariants gt) + (flat:check-tensor-equal? (↓ gt) (↓ (tensor (tensor 2.0 2.0 2.0) + (tensor 1.0 1.0 1.0))))) + (let-values (((gt gu) (s2-∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0)))) + (check-compiler-invariants gt) + (check-compiler-invariants gu) + (flat:check-tensor-equal? (↓ gt) (↓ (tensor 1.0 1.0 1.0))) + (flat:check-tensor-equal? (↓ gu) (↓ (tensor 1.0 1.0 1.0)))) + (let-values (((gt gu) (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 5.0 6.0 6.0) + (tensor 7.0 8.0 6.0)) + (tensor (tensor 8.0 7.0 6.0) + (tensor 6.0 5.0 6.0))) + (tensor (tensor (tensor 6.0 8.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 8.0 2.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 5.0 1.0 6.0))) + (tensor (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)))))) + (check-compiler-invariants gt) + (check-compiler-invariants gu) + (flat:check-tensor-equal? + (↓ gt) + (↓ (reshape '(3 2 3) (list->tensor (make-list 18 1.0))))) + (flat:check-tensor-equal? + (↓ gu) + (↓ (reshape '(3 2 3) (list->tensor (make-list 18 1.0)))))) + + (check-compiler-invariants test-env-flat-scalar) + (flat:check-tensor-equal? (↓ test-env-flat-scalar) + (flat:tensor 3.0)) + + (check-compiler-invariants test-common-subexpr) + (flat:check-tensor-equal? (↓ test-common-subexpr) + (flat:tensor 1.0 1.0)) + + (check-compiler-invariants test-common-nested-subexprs) + (flat:check-tensor-equal? (↓ test-common-nested-subexprs) + (flat:tensor 1.0 1.0)) + + (check-pred + (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) + (vector->list (flat:flat-store (↓ test-build-random))))) diff --git a/lazy/tensors/test/test-A-equality.rkt b/lazy/tensors/test/test-A-equality.rkt index 9c87732..7d5e04d 100644 --- a/lazy/tensors/test/test-A-equality.rkt +++ b/lazy/tensors/test/test-A-equality.rkt @@ -1,5 +1,6 @@ (module+ test (require rackunit) + (require "0-lazy.rkt") (define t0 (reshape '(2 3 4) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt new file mode 100644 index 0000000..e8e4c94 --- /dev/null +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -0,0 +1,73 @@ +(module+ test + (require rackunit) + (require "B-test-programs.rkt") + (require "0-lazy.rkt") + + (define-check (check-signatures-equal? t1 t2) + (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) + ((eds-instrs-2 ds2) (extract-data-segment t2))) + (let ((sig1 (generate-signature eds-instrs-1)) + (sig2 (generate-signature eds-instrs-2))) + (with-check-info + (('extracted-instrs-1 eds-instrs-1) + ('extracted-instrs-2 eds-instrs-2) + ('data-segment-1 ds1) + ('data-segment-2 ds2) + ('signature-1 sig1) + ('signature-2 sig2)) + (unless (equal? sig1 sig2) + (fail-check "signature mismatch")))))) + + (define-check (check-signatures-not-equal? t1 t2) + (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) + ((eds-instrs-2 ds2) (extract-data-segment t2))) + (let ((sig1 (generate-signature eds-instrs-1)) + (sig2 (generate-signature eds-instrs-2))) + (with-check-info + (('extracted-instrs-1 eds-instrs-1) + ('extracted-instrs-2 eds-instrs-2) + ('data-segment-1 ds1) + ('data-segment-2 ds2) + ('signature-1 sig1) + ('signature-2 sig2)) + (when (equal? sig1 sig2) + (fail-check "signatures musn't match")))))) + + (define test-tensor-r1-1 (get-test-program 'tensor-r1-1)) + (check-signatures-equal? test-tcomp-tref + (make-tref-test-program test-tensor-r1-1)) + (check-signatures-not-equal? test-tcomp-tref + (make-list->tensor-test-program `(,test-tensor-r1-1))) + + (define tensor-r1 (get-test-program 'tensor-r1-0)) + (check-signatures-equal? (*-ρ 2 tensor-r1) (*-ρ 3 tensor-r1)) + (check-signatures-not-equal? (*-ρ 2 3) (*-ρ 3 3)) + + (define v^ (random-tensor (list 10 4))) + (define r^ (random-tensor (list 10 4 2))) + (check-signatures-equal? mean-v (mean v^)) + (check-signatures-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (tensor (tensor 12 23 44) + (tensor 23 46 57)))) + (check-signatures-not-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (tensor (tensor 12 23 44) + (tensor 23 46 57) + (tensor 67 32 58)))) + (check-signatures-not-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (reshape '(2 3) (tensor 1 2 3 4 5 6)))) + (check-signatures-equal? variance-v (variance v^)) + (check-signatures-equal? mean-r (mean r^)) + (check-signatures-equal? variance-r (variance r^)) + (check-signatures-not-equal? mean-v mean-r) + (check-signatures-not-equal? mean-v variance-v) + (check-signatures-not-equal? variance-v mean-r) + (check-signatures-equal? (+-ρ mean-v (tensor 0 1 2 3 4 5 6 7 8 9)) + (+-ρ (mean v^) (tensor 0 1 2 3 4 5 6 7 8 9))) + + (let ((a 2) + (b 3)) + ;;TODO: Fix these test cases + (let*-values (((da- db-) (d- a b 1.0)) + ((da+ db+) (d+ a b 1.0))) + (check-signatures-not-equal? da- da+) + (check-signatures-not-equal? db- db+)))) diff --git a/malted/test/test-O-init.rkt b/malted/test/test-O-init.rkt index 95c18dc..0316e13 100644 --- a/malted/test/test-O-init.rkt +++ b/malted/test/test-O-init.rkt @@ -1,28 +1,24 @@ (module+ test (require rackunit) - ;; TODO: Make this better. We musn't break abstraction boundaries - (require "../lazy/tensors/0-lazy.rkt") + (require (only-in "../base.rkt" ρ)) - (define v (init-shape (list 10 4))) + (define v (init-shape (list 1000 4))) (define mean-v - (abs (/ (sum (sum v)) 40))) + (abs (/ (sum (sum v)) 4000))) (define variance-v - (- (/ (sum (sum (* v v))) 40) (* mean-v mean-v))) - (check-true (< (force/eval mean-v) 0.05)) - (pretty-print (get-compiled variance-v)) - (check-true (let ((forced (force/eval variance-v))) + (- (/ (sum (sum (* v v))) 4000) (* mean-v mean-v))) + (check-true (< (ρ mean-v) 0.05)) + (check-true (let ((forced (ρ variance-v))) (and (>= forced 0.4) (<= forced 0.6)))) ;; Here variance will be 2/8 = 0.25 - (define r (init-shape (list 10 4 2))) - (define mean-r (abs (/ (sum (sum (sum r))) 80))) - (define variance-r (- (/ (sum (sum (sum (* r r)))) 80) + (define r (init-shape (list 1000 4 2))) + (define mean-r (abs (/ (sum (sum (sum r))) 8000))) + (define variance-r (- (/ (sum (sum (sum (* r r)))) 8000) (* mean-r mean-r))) - (check-true (< (force/eval mean-r) 0.05)) - (pretty-print (get-compiled variance-r)) - (check-true (let ((forced (force/eval variance-r))) - (println forced) + (check-true (< (ρ mean-r) 0.05)) + (check-true (let ((forced (ρ variance-r))) (and (>= forced 0.22) (<= forced 0.28))))) From c0520fa27ac6bade401827170a4ef69bbc480805 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Thu, 16 Nov 2023 15:50:46 -0500 Subject: [PATCH 65/83] [add-lazy]Fix some caching issues and implement optimizations --- lazy.rkt | 2 +- lazy/autodiff/A-autodiff.rkt | 3 +- lazy/autodiff/B-prims.rkt | 66 ++- lazy/autodiff/D-test-helpers.rkt | 4 - lazy/ext-ops/C-star-2-1.rkt | 2 +- lazy/ext-ops/D-sum.rkt | 8 +- lazy/ext-ops/E-argmax.rkt | 4 +- lazy/ext-ops/F-max.rkt | 4 +- lazy/ext-ops/G-correlate.rkt | 5 +- lazy/ext-ops/K-concat.rkt | 2 +- lazy/ext-ops/test/test-A-scalar-ops.rkt | 5 +- lazy/ext-ops/test/test-G-correlate.rkt | 4 +- lazy/tensors.rkt | 2 +- lazy/tensors/0-lazy.rkt | 67 +-- lazy/tensors/1-reflect.rkt | 7 +- lazy/tensors/B-test-programs.rkt | 101 +++-- lazy/tensors/c0-ast.rkt | 14 +- lazy/tensors/c1-racket-runtime.rkt | 12 +- lazy/tensors/c2-interpreter.rkt | 29 +- lazy/tensors/c3-compiler.rkt | 517 +++++++++++------------- lazy/tensors/test/test-1-reflect.rkt | 9 +- lazy/tensors/test/test-c3-compiler.rkt | 41 +- 22 files changed, 463 insertions(+), 445 deletions(-) diff --git a/lazy.rkt b/lazy.rkt index 87664db..7b5bba9 100644 --- a/lazy.rkt +++ b/lazy.rkt @@ -14,7 +14,7 @@ ext1-ρ ext2-ρ ext1-∇ ext2-∇ - print-compiler? + print-compiler? compiler-cache dual dual? ρ κ ∇ ∇¹ (rename-out (∇ gradient-of)) map* diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index e7399a1..e76b165 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -61,7 +61,8 @@ (define ∇ (λ (f theta) (let ((wrt (map* dual* theta))) - (∇-once (f wrt) wrt)))) + ;; TODO: try forcing (f wrt) to see if it fixes caching issues + (∇-once (f wrt) #;(↓ (f wrt)) wrt)))) (define ∇¹ (λ (f) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index ee39e10..07c1fb4 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -3,24 +3,7 @@ (require "../tensors.rkt") (require "A-autodiff.ss") -(define ρ-function - (λ (f) (f ρ-function))) - -(define ∇-function - (λ (f) (f ∇-function))) - -(define shape-fn - (λ (f) (f shape-fn))) - -;;TODO: make prim1 and prim2 func-callable structures using the prop:procedure -;;struct property - -(struct prim (ρ-fn ∇-fn shape-fn - signature ;;autogenerate this before runtime to avoid - ;;changing this during runtime - proc ;; This will be the prim*-dual func - prealloc? ;; use this to redefine expects-preallocated? - ) +(struct prim (ρ-fn ∇-fn shape-fn signature expects-prealloc? proc) #:property prop:procedure (λ (this . args) (apply (prim-proc this) args))) @@ -30,13 +13,11 @@ ;;prims (define prim1 - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) - (λ (daf) - (cond - ((eq? daf ρ-function) ρ-fn) - ((eq? daf ∇-function) ∇-fn) - ((eq? daf shape-fn) shape) - (else (prim1-dual ρ-fn ∇-fn daf)))))) + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (let ((prim-sign (symbol->string (gensym 'prim1)))) + (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (λ (da) + (prim1-dual ρ-fn ∇-fn da)))))) (define prim1-dual (λ (ρ-fn ∇-fn da) @@ -48,14 +29,11 @@ ((κ da) da ga σ)))))))) (define prim2 - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)]) - (λ ds - (let ((daf (ref ds 0))) - (cond - ((eq? daf ρ-function) ρ-fn) - ((eq? daf ∇-function) ∇-fn) - ((eq? daf shape-fn) shape) - (else (prim2-dual ρ-fn ∇-fn daf (ref ds 1)))))))) + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (let ((prim-sign (symbol->string (gensym 'prim2)))) + (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (λ (da db) + (prim2-dual ρ-fn ∇-fn da db)))))) (define prim2-dual (λ (ρ-fn ∇-fn da db) @@ -73,16 +51,26 @@ ;;---------------------------- (define ext1 (λ (f n) + (unless (prim? f) + (error 'ext1-prim "Function to be extended must be a primitive. Found: ~a" f)) (prim1 - (ext1-ρ (ρ-function f) n (shape-fn f)) - (ext1-∇ (∇-function f) n (shape-fn f)) - (shape-fn f)))) + (ext1-ρ (prim-ρ-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (prim-signature f)) + (ext1-∇ (prim-∇-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (prim-signature f)) + (prim-shape-fn f) + #f))) (define ext2 (λ (f m n) + (unless (prim? f) + (error 'ext2-prim "Function to be extended must be a primitive. Found: ~a" f)) (prim2 - (ext2-ρ (ρ-function f) m n (shape-fn f)) - (ext2-∇ (∇-function f) m n (shape-fn f)) - (shape-fn f)))) + (ext2-ρ (prim-ρ-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (prim-signature f)) + (ext2-∇ (prim-∇-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (prim-signature f)) + (prim-shape-fn f) + #f))) (provide prim1 prim2 ext1 ext2) diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index 9da4331..1b1df9f 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -5,10 +5,6 @@ (require rackunit) -(define forced-ρ - (λ (d) - (↓ (ρ d)))) - (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) (let* ((y (↓ (apply fn args))) diff --git a/lazy/ext-ops/C-star-2-1.rkt b/lazy/ext-ops/C-star-2-1.rkt index 75ecaee..629b1ba 100644 --- a/lazy/ext-ops/C-star-2-1.rkt +++ b/lazy/ext-ops/C-star-2-1.rkt @@ -30,7 +30,7 @@ s)) (define *-2-1 - (prim2 *-2-1-base-ρ *-2-1-base-∇ *-2-1-shape)) + (prim2 *-2-1-base-ρ *-2-1-base-∇ *-2-1-shape #t)) (define d*-2-1 (ext2 *-2-1 2 1)) diff --git a/lazy/ext-ops/D-sum.rkt b/lazy/ext-ops/D-sum.rkt index 046131f..2fb6c83 100644 --- a/lazy/ext-ops/D-sum.rkt +++ b/lazy/ext-ops/D-sum.rkt @@ -23,13 +23,13 @@ (refr st 1))) (define sum-1 - (prim1 sum-1-ρ sum-1-∇ sum-shape)) + (prim1 sum-1-ρ sum-1-∇ sum-shape #t)) (define d-sum (ext1 sum-1 1)) (define sum-ρ - (ext1-ρ sum-1-ρ 1 sum-shape)) + (ext1-ρ sum-1-ρ 1 sum-shape #t)) (provide d-sum sum-ρ) @@ -54,13 +54,13 @@ (refr s 1))) (define sum-cols-2 - (prim1 sum-cols-2-ρ sum-cols-2-∇ sum-cols-shape)) + (prim1 sum-cols-2-ρ sum-cols-2-∇ sum-cols-shape #t)) (define d-sum-cols (ext1 sum-cols-2 2)) (define sum-cols-ρ - (ext1-ρ sum-cols-2-ρ 2 sum-cols-shape)) + (ext1-ρ sum-cols-2-ρ 2 sum-cols-shape #t)) (include "test/test-D-sum.rkt") diff --git a/lazy/ext-ops/E-argmax.rkt b/lazy/ext-ops/E-argmax.rkt index ad4aefe..f18148b 100644 --- a/lazy/ext-ops/E-argmax.rkt +++ b/lazy/ext-ops/E-argmax.rkt @@ -27,13 +27,13 @@ '())) (define argmax-1 - (prim1 argmax-1-ρ argmax-1-∇ argmax-shape)) + (prim1 argmax-1-ρ argmax-1-∇ argmax-shape #t)) (define d-argmax (ext1 argmax-1 1)) (define argmax-ρ - (ext1-ρ argmax-1-ρ 1 argmax-shape)) + (ext1-ρ argmax-1-ρ 1 argmax-shape #t)) (include "test/test-E-argmax.rkt") diff --git a/lazy/ext-ops/F-max.rkt b/lazy/ext-ops/F-max.rkt index 5a54bd5..a5e70b6 100644 --- a/lazy/ext-ops/F-max.rkt +++ b/lazy/ext-ops/F-max.rkt @@ -35,13 +35,13 @@ (cdr st))) (define max-1 - (prim1 max-1-ρ max-1-∇ max-shape)) + (prim1 max-1-ρ max-1-∇ max-shape #t)) (define d-max (ext1 max-1 1)) (define max-ρ - (ext1-ρ max-1-ρ 1 max-shape)) + (ext1-ρ max-1-ρ 1 max-shape #t)) (include "test/test-F-max.rkt") diff --git a/lazy/ext-ops/G-correlate.rkt b/lazy/ext-ops/G-correlate.rkt index 4cbe0e5..cfe156b 100644 --- a/lazy/ext-ops/G-correlate.rkt +++ b/lazy/ext-ops/G-correlate.rkt @@ -58,7 +58,8 @@ (prim2 (correlate-3-1-ρ nd md qd) (correlate-3-1-∇ nd md qd) - correlate-shape))) + correlate-shape + #t))) (define d-correlate (λ (bank signal) @@ -82,7 +83,7 @@ (q (/ (- m 1) 2)) ;; This is the padding. (qd (* q d)) (md (* m d))) - ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape) + ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape #t) bank signal)))) (define last diff --git a/lazy/ext-ops/K-concat.rkt b/lazy/ext-ops/K-concat.rkt index 1f3ae9a..63f01b2 100644 --- a/lazy/ext-ops/K-concat.rkt +++ b/lazy/ext-ops/K-concat.rkt @@ -36,7 +36,7 @@ (vector-ref vz (+ iz i))))))))) (define concat-base - (prim2 concat-base-ρ concat-base-∇ concat-shape)) + (prim2 concat-base-ρ concat-base-∇ concat-shape #t)) (define d-concat-n (λ (n) diff --git a/lazy/ext-ops/test/test-A-scalar-ops.rkt b/lazy/ext-ops/test/test-A-scalar-ops.rkt index 4a5f87e..7997b7a 100644 --- a/lazy/ext-ops/test/test-A-scalar-ops.rkt +++ b/lazy/ext-ops/test/test-A-scalar-ops.rkt @@ -1,13 +1,12 @@ (module+ test (require rackunit) - (require (only-in "../tensors.rkt" tensor print-compiler?)) + (require (only-in "../tensors.rkt" tensor)) ;; Check basic numericals (let ((a 2) (b 3)) (check-ρ-∇ (d+ a b) 5 (list 1.0 1.0)) - (parameterize ((print-compiler? '(Cache-Hit))) - (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0))) + (check-ρ-∇ (d- a b) -1 (list 1.0 -1.0)) (check-ρ-∇ (d* a b) 6 (list 3.0 2.0)) (check-ρ-∇ (d/ a b) 2/3 diff --git a/lazy/ext-ops/test/test-G-correlate.rkt b/lazy/ext-ops/test/test-G-correlate.rkt index 417723c..e53a2fb 100644 --- a/lazy/ext-ops/test/test-G-correlate.rkt +++ b/lazy/ext-ops/test/test-G-correlate.rkt @@ -40,10 +40,10 @@ (tensor 23 24)))) (define corr-ρ - (ext2-ρ (correlate-3-1-ρ 12 6 2) 3 1 correlate-shape)) + (ext2-ρ (correlate-3-1-ρ 12 6 2) 3 1 correlate-shape #t)) (define corr-∇ - (ext2-∇ (correlate-3-1-∇ 12 6 2) 3 1 correlate-shape)) + (ext2-∇ (correlate-3-1-∇ 12 6 2) 3 1 correlate-shape #t)) (check-tensor-equal? (corr-ρ bank signal) ;; Should be of size nb diff --git a/lazy/tensors.rkt b/lazy/tensors.rkt index 21db9f5..330a31a 100644 --- a/lazy/tensors.rkt +++ b/lazy/tensors.rkt @@ -14,7 +14,7 @@ (provide ↓ scalarize) -(provide print-compiler?) +(provide print-compiler? compiler-cache) ;; These will get overriden by duals (provide tensor?) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index e823ff7..cc0d047 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -64,6 +64,8 @@ instructions refering to the same gensym variable (λ (lst) (cond [(andmap number? lst) (apply flat:tensor lst)] + [(andmap (λ (v) (and (tpromise? v) (flat:flat? (tpromise-tensor v)))) lst) + (apply flat:tensor (map tpromise-tensor lst))] [else (tcomp-list->tensor lst)]))) #; @@ -148,7 +150,7 @@ instructions refering to the same gensym variable (define build-tpromise (λ (s f) - (tpromise (tcomp-build-tensor s f) s))) + (tpromise (flat:build-tensor s f) s))) (define tp-trefs (λ (tp b) @@ -164,19 +166,26 @@ instructions refering to the same gensym variable `(,(length b) . ,(cdr (tpromise-shape tp))))]))) +;; Default arguments shape-fn and expects-prealloc? need not be passed when f is +;; a function on scalars and doesn't expect a preallocated output vector as its +;; argument. The signature argument is only supposed to be passed within the +;; definition of ext1 and ext2 functions in B-prims.rkt. (define tp-ext1-ρ - (λ (f m [shape-fn scalar-shape]) + (λ (f m + [shape-fn scalar-shape] + [expects-prealloc? #f] + [signature (format "~a" f)]) (λ (tp) (cond [(scalar? tp) (f tp)] [(and (tpromise? tp) (null? (tpromise-shape tp))) (tpromise - (tcomp-ext1-ρ-scalar f tp) + (tcomp-ext1-ρ-scalar f signature tp) '())] - [(flat:expects-preallocated? f) + [expects-prealloc? (tpromise - (tcomp-ext1-ρ f m shape-fn tp) + (tcomp-ext1-ρ f signature m shape-fn tp) (merge-shapes (tp-shape tp) m @@ -188,15 +197,19 @@ instructions refering to the same gensym variable (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) (tpromise - (tcomp-ext1-ρ flat-f m shape-fn tp) + (tcomp-ext1-ρ flat-f signature m shape-fn tp) (merge-shapes (tp-shape tp) m (shape-fn (min-shape m (tp-shape tp))))))])))) +;; See comment for tp-ext1-ρ (define tp-ext2-ρ - (λ (f m n [shape-fn scalar-shape]) + (λ (f m n + [shape-fn scalar-shape] + [expects-prealloc? #f] + [signature (format "~a" f)]) (λ (tp-t tp-u) (cond ((and (number? tp-t) (number? tp-u)) @@ -204,8 +217,8 @@ instructions refering to the same gensym variable [(and (tpromise? tp-t) (tpromise? tp-u) (null? (tpromise-shape tp-t)) (null? (tpromise-shape tp-u))) - (tpromise (tcomp-ext2-ρ-scalar f tp-t tp-u) '())] - [(flat:expects-preallocated? f) + (tpromise (tcomp-ext2-ρ-scalar f signature tp-t tp-u) '())] + [expects-prealloc? (let* ((s0 (tp-shape tp-t)) (s1 (tp-shape tp-u)) (sf0 (min-shape m s0)) @@ -214,7 +227,7 @@ instructions refering to the same gensym variable (tpromise (tcomp-ext2-ρ (ensure-tpromise tp-t) (ensure-tpromise tp-u) - f m n shape-fn) + f signature m n shape-fn) (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))] [else @@ -234,21 +247,25 @@ instructions refering to the same gensym variable (tpromise (tcomp-ext2-ρ (ensure-tpromise tp-t) (ensure-tpromise tp-u) - flat-f m n shape-fn) + flat-f signature m n shape-fn) (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) (define scalar-shape (λ (s0 [s1 '()]) '())) +;; See comment for tp-ext1-ρ (define tp-ext1-∇ - (λ (f m [shape-fn scalar-shape]) + (λ (f m + [shape-fn scalar-shape] + [expects-prealloc? #f] + [signature (format "~a" f)]) (λ (tp zp) (cond ((number? tp) (f tp zp)) - ((flat:expects-preallocated? f) + (expects-prealloc? (tpromise - (tcomp-ext1-∇ tp (ensure-tpromise zp) f m shape-fn) + (tcomp-ext1-∇ tp (ensure-tpromise zp) f signature m shape-fn) (tp-shape tp))) (else (let* ((in-shape (tpromise-shape tp)) @@ -256,18 +273,22 @@ instructions refering to the same gensym variable (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) (tpromise - (tcomp-ext1-∇ tp (ensure-tpromise zp) flat-f m shape-fn) + (tcomp-ext1-∇ tp (ensure-tpromise zp) flat-f signature m shape-fn) (tp-shape tp)))))))) +;; See comment for tp-ext1-ρ (define tp-ext2-∇ - (λ (f m n [shape-fn scalar-shape]) + (λ (f m n + [shape-fn scalar-shape] + [expects-prealloc? #f] + [signature (format "~a" f)]) (let ((tp-f (λ (f tp-t tp-u tp-z) - (tp-d-ext2^ f m n shape-fn + (tp-d-ext2^ f signature m n shape-fn tp-t tp-u tp-z)))) (λ (tp-t tp-u tp-z) (cond - ((flat:expects-preallocated? f) + (expects-prealloc? (tp-f f (ensure-tpromise tp-t) (ensure-tpromise tp-u) @@ -284,13 +305,13 @@ instructions refering to the same gensym variable (define tp-d-ext2^ - (λ (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z) - (let* ((out0 (ext2-∇-result 'uncalculated)) - (out1 (ext2-∇-result 'uncalculated))) + (λ (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) + (let* ((out0 'uncalculated) + (out1 'uncalculated)) (values - (tpromise (tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 0) + (tpromise (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 0) (tp-shape tp-t0)) - (tpromise (tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 1) + (tpromise (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 1) (tp-shape tp-t1)))))) (define ensure-tpromise diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt index 313c33c..9dbae5f 100644 --- a/lazy/tensors/1-reflect.rkt +++ b/lazy/tensors/1-reflect.rkt @@ -3,6 +3,7 @@ (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (require "c0-ast.rkt") (require (only-in "c3-compiler.rkt" + compiler-cache print-compiler? get-compiled compile-tensor)) @@ -17,9 +18,9 @@ [(tpromise t _) #:when (or (flat:flat? t) (number? t) (tcomp? t)) - (let-values (((instrs env) + (let-values (((instrs data-segment) (compile-tensor t))) - (let ((res (interp-racket instrs env))) + (let ((res (interp-racket instrs data-segment))) (set-tpromise-tensor! tp res) res))] ;; NOTE: This case runs when we use tp-scalarize to turn @@ -51,6 +52,6 @@ (provide ↓ force*1 force*2) -(provide print-compiler? get-compiled +(provide print-compiler? compiler-cache get-compiled (rename-out (tp-scalarize scalarize))) diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt index c01e5b6..f75b2b1 100644 --- a/lazy/tensors/B-test-programs.rkt +++ b/lazy/tensors/B-test-programs.rkt @@ -24,7 +24,70 @@ 'tensor-r2-0 (test-program-data (λ () (tensor (tensor 1 2 3) (tensor 4 5 6))) - (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6))))) + (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6))) + 'build-tensor-r2-0 (test-program-data + (λ () + (build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))) + (flat:build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))) + 'build-tensor-r3-0 (test-program-data + (λ () + (build-tensor '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) + (flat:build-tensor '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) + 'build-tensor-r3-1 (test-program-data + (λ () + (build-tensor '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + (flat:build-tensor '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + 'extract-ds-once-tref (test-program-data + (λ () + (let ((n (tref (get-test-program 'tensor-r1-0) 1))) + (+-ρ n n))) + 4) + 'extract-ds-once-trefs (test-program-data + (λ () + (let ((tp (trefs (get-test-program 'tensor-r1-0) '(0 2)))) + (+-ρ tp tp))) + (flat:tensor 2 6)) + 'built-tensor (test-program-data + (λ () + (build-tensor test-build-shape + (λ (i) + (let ([row (car i)] + [column (cadr i)]) + (+ (* (sub1 (car test-build-shape)) + row) + column))))) + (flat:tensor (flat:tensor 0 1 2) + (flat:tensor 3 4 5) + (flat:tensor 6 7 8) + (flat:tensor 9 10 11))) + 'multi-built-tensor (test-program-data + (λ () + (+-ρ (get-test-program 'build-tensor-r2-0) + (tref (get-test-program 'build-tensor-r3-1) 0))) + ((flat:ext2-ρ * 0 0) 2 (flat:build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y)))))) + )) + (define get-test-program (λ (name) ((test-program-data-prog-thunk (hash-ref test-programs name))))) @@ -40,16 +103,10 @@ ,(get-test-program 'tensor-r1-0) ,(get-test-program 'tensor-r1-0)))) (define test-build-shape '(4 3)) -(define test-built-tensor (build-tensor test-build-shape - (λ (i) - (let ([row (car i)] - [column (cadr i)]) - (+ (* (sub1 (car test-build-shape)) - row) - column))))) + (define test-refs '(0 2)) -(define test-trefs (trefs test-built-tensor test-refs)) -(define test-reshape (reshape '(3 2 1) (trefs test-built-tensor '(1 3)))) +(define test-trefs (trefs (get-test-program 'built-tensor) test-refs)) +(define test-reshape (reshape '(3 2 1) (trefs (get-test-program 'built-tensor) '(1 3)))) (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) @@ -57,7 +114,7 @@ (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) (+ sum (vref in-v i)))))) -(define sum (ext1-ρ sum-f 1)) +(define sum (ext1-ρ sum-f 1 (λ (s) '()) #t)) (define test-tp-sum (sum (get-test-program 'tensor-r2-0))) (define test-tp-sum-nested (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) @@ -67,10 +124,8 @@ (define test-tp-id-scalar (id-ρ (sum (tensor 4 5 6)))) (define t0 - (build-tensor '(2 3 4) - (λ (i) - (match-define `(,x ,y ,z) i) - (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) + (get-test-program 'build-tensor-r3-0)) + (define *-ρ (ext2-ρ * 0 0)) (define t0sqr (*-ρ t0 t0)) @@ -82,26 +137,20 @@ (vref v1 (+ i1 (modulo j0 s1)))))))) (define t1 - (build-tensor '(5 6) - (λ (i) - (match-define `(,x ,y) i) - (* 2.0 (+ (* x 6) y))))) + (get-test-program 'build-tensor-r2-0)) (define t2 (build-tensor '(6) (λ (i) (* 3.0 (car i))))) (define *-2-1 - (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0))) + (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0) #t)) (define r-1-2 (*-2-1 t1 t2)) (define t3 - (build-tensor '(3 5 6) - (λ (i) - (match-define `(,x ,y ,z) i) - (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + (get-test-program 'build-tensor-r3-1)) (define t4 (build-tensor '(3 6) @@ -141,7 +190,7 @@ (for* ([i (in-range it (+ it st))]) (vset! g i (vref vz iz))))) -(define sum-∇ (ext1-∇ sum-1-∇ 1 (λ (s) '()))) +(define sum-∇ (ext1-∇ sum-1-∇ 1 (λ (s) '()) #t)) ;; t and u must have the same shape (define s2-f (lambda (t u) (tensor (sum t) (sum u)))) @@ -150,7 +199,7 @@ (for* ([i (in-range it (+ it st))]) (vset! g0 i (vref vz iz)) (vset! g1 i (vref vz (+ iz 1)))))) -(define s2-∇ (ext2-∇ s2-d 1 1 (λ (s0 s1) (list 2)))) +(define s2-∇ (ext2-∇ s2-d 1 1 (λ (s0 s1) (list 2)) #t)) (define test-env-flat-scalar ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) diff --git a/lazy/tensors/c0-ast.rkt b/lazy/tensors/c0-ast.rkt index 1cf0c7c..b566691 100644 --- a/lazy/tensors/c0-ast.rkt +++ b/lazy/tensors/c0-ast.rkt @@ -10,7 +10,6 @@ (: s (Listof Natural)) ;; non-empty #; (: f (-> (Listof Natural) Number)) -(struct tcomp-build-tensor tcomp (s f) #:transparent) #; (: tp tpromise) #; @@ -26,12 +25,12 @@ (-> (Vector Number) Natural (Listof Natural) (Vector Number) Natural (Listof Natural) (Vector Number) Natural (Listof Natural)))) -(struct tcomp-ext1-ρ-scalar tcomp (f tp) #:transparent) -(struct tcomp-ext1-ρ tcomp (f m shape-fn tp) #:transparent) -(struct tcomp-ext2-ρ-scalar tcomp (f tp-t tp-u) #:transparent) -(struct tcomp-ext2-ρ tcomp (tp-t tp-u f m n shape-fn) #:transparent) -(struct tcomp-ext1-∇ tcomp (tp zp f m shape-fn) #:transparent) -(struct tcomp-ext2-∇ tcomp (fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) +(struct tcomp-ext1-ρ-scalar tcomp (f sign tp) #:transparent) +(struct tcomp-ext1-ρ tcomp (f sign m shape-fn tp) #:transparent) +(struct tcomp-ext2-ρ-scalar tcomp (f sign tp-t tp-u) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u f sign m n shape-fn) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp f sign m shape-fn) #:transparent) +(struct tcomp-ext2-∇ tcomp (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) (struct tcomp-let tcomp (lhs rhs body) #:transparent) @@ -58,7 +57,6 @@ (provide (struct-out tcomp) (struct-out tcomp-list->tensor) - (struct-out tcomp-build-tensor) (struct-out tcomp-tref) (struct-out tcomp-trefs) (struct-out tcomp-ext1-ρ-scalar) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 62db3a8..68e9333 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -55,10 +55,8 @@ vz (+ offz iz) stride-z))) - (set-ext2-∇-result-res! out0 - (scalarize (flat s0 g0 0))) - (set-ext2-∇-result-res! out1 - (scalarize (flat s1 g1 0))))))))) + (data-segment-set! out0 (scalarize (flat s0 g0 0))) + (data-segment-set! out1 (scalarize (flat s1 g1 0))))))))) (define rt:trefs (λ (ft b) @@ -69,6 +67,10 @@ (define data-segment (make-parameter #f)) +(define data-segment-set! + (λ (i v) + (vector-set! (data-segment) i v))) + (define data-segment-ref (λ (i) (vector-ref (data-segment) i))) @@ -81,4 +83,4 @@ (provide runtime flat? flat:build-tensor flat:list->tensor flat:tref rt:trefs (struct-out ext2-∇-result) ext2-∇-forcer scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ - flat flat-store flat-offset flat-ext1-ρ) + flat flat-store flat-offset flat-ext1-ρ data-segment) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index 3c7939a..d1f465d 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -5,13 +5,7 @@ runtime flat? flat:build-tensor flat:list->tensor flat:tref rt:trefs ext2-∇-result-res ext2-∇-forcer scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store - flat-offset flat-ext1-ρ)) - -(define make-instrs - (λ (instrs ds) - `(begin - (data-segment ,ds) - ,instrs))) + flat-offset flat-ext1-ρ data-segment)) (define interp-tensor-tcomp (λ (tc env ds) @@ -21,15 +15,15 @@ (for/list ((arg lst)) (interp-tensor-expr arg env ds)))) (flat:list->tensor eval-list))] - [(tcomp-build-tensor s f) - (flat:build-tensor s f)] [(tcomp-tref tp i) (flat:tref (interp-tensor-expr tp env ds) (interp-tensor-expr i env ds))] [(tcomp-trefs tp b) (rt:trefs (interp-tensor-expr tp env ds) (interp-tensor-expr b env ds))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ _ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + ;; TODO: fix this case because we now use the data segment rather than + ;; ext2-∇-result for output (let* ([b (if (zero? i) out0 out1)] [v (ext2-∇-result-res b)]) (cond @@ -41,21 +35,21 @@ out0 out1) (ext2-∇-result-res b)) (else v)))] - [(tcomp-ext1-∇ tp zp f m shape-fn) + [(tcomp-ext1-∇ tp zp f _ m shape-fn) (scalarize (flat-ext1-∇ f m shape-fn (ensure-flat (interp-tensor-expr tp env ds)) (ensure-flat (interp-tensor-expr zp env ds))))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) + [(tcomp-ext2-ρ-scalar f _ tp-t tp-u) (f (interp-tensor-expr tp-t env ds) (interp-tensor-expr tp-u env ds))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f _ m n shape-fn) (scalarize (flat-ext2-ρ f m n shape-fn (ensure-flat (interp-tensor-expr tp-t env ds)) (ensure-flat (interp-tensor-expr tp-u env ds))))] - [(tcomp-ext1-ρ-scalar f tp) + [(tcomp-ext1-ρ-scalar f _ tp) (f (interp-tensor-expr tp env ds))] - [(tcomp-ext1-ρ f m shape-fn tp) + [(tcomp-ext1-ρ f _ m shape-fn tp) (scalarize (flat-ext1-ρ f m shape-fn (ensure-flat (interp-tensor-expr tp env ds))))] @@ -93,6 +87,7 @@ (define interp-racket (lambda (instrs ds) - (eval (make-instrs instrs ds) runtime))) + (parameterize ((data-segment ds)) + (eval instrs runtime)))) -(provide interp-racket interp-tensor make-instrs) +(provide interp-racket interp-tensor) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 6899619..4cfa685 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -1,7 +1,7 @@ #lang racket (require "c0-ast.rkt") -(require (only-in "c2-interpreter.rkt" make-instrs interp-tensor interp-racket)) +(require (only-in "c2-interpreter.rkt" interp-tensor interp-racket)) (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (require rackunit) @@ -18,88 +18,118 @@ ;;TODO: later eds and gs passes should not be needed because the tcomp AST nodes ;;should have a signature and dss field which will be populated at the time of ;;their instantiation. Then we just access those fields from the AST node rather -;;than computing them. Use a global data segment that has flat tensors used by all tcomp nodes in our program. - -;;Extracts the data segment which is a vector that contains scalars (arguments -;;to tref), flat tensors and flat tensor¹ of indices that will be the arguments -;;to trefs. +;;than computing them. Use a global data segment that has flat tensors used by +;;all tcomp nodes in our program. + +;;Extracts the data segment which is a vector that contains +;; * scalars (arguments to tref) +;; * flat tensors +;; * flat tensor¹ of indices that will be the arguments to trefs +;; * the symbol 'uncalculated as an initial placeholder for the output of +;; tcomp-ext2-∇ which will be later replaced by the flat tensor output +;; +;; TODO: Remove all uses of build-refs because we longer have tcomp-build-tensor nodes (define extract-data-segment (λ (t) - (let-values (((t^ data-segment-stack) (eds-expr t '()))) + (let-values (((t^ data-segment-stack build-refs) (eds-expr t '() (hasheq)))) ;; convert data segment stack to data segment array - (values t^ (list->vector (reverse data-segment-stack)))))) + (values t^ + (list->vector (reverse data-segment-stack)) + build-refs)))) + +;; Checks if a member equivalent to v exists in dss using equiv? and based on +;; that returns the dss index where v was inserted and the new dss with +;; insertion as values +;; TODO: Reconsider performance impact of this function +(define insert-unless-exists + (λ (v dss equiv?) + (cond + ((member v dss equiv?) + => (λ (res/rest) + (values (length (cdr res/rest)) dss))) + (else (values (length dss) (cons v dss)))))) (define eds-expr - (λ (t dss) + (λ (t dss build-refs) (match t (s #:when (number? s) - (values s dss)) + (values s dss build-refs)) (ft #:when (flat? ft) - (cond - ((memq ft dss) - => - (λ (res) - (values (tcomp-ds-ref (length (cdr res))) dss))) - (else (values (tcomp-ds-ref (length dss)) (cons ft dss))))) + (let-values (((idx dss^) (insert-unless-exists ft dss eq?))) + (values (tcomp-ds-ref idx) dss^ build-refs))) ((tpromise tc s) - (let-values (((tc^ dss^) (eds-expr tc dss))) + (let-values (((tc^ dss^ build-refs^) (eds-expr tc dss build-refs))) (cond - ((number? tc^) (values tc^ dss^)) - (else (values (tpromise tc^ s) dss^))))) - ((tcomp) (eds-tcomp t dss))))) + ((number? tc^) (values tc^ dss^ build-refs^)) + (else (values (tpromise tc^ s) dss^ build-refs^))))) + ((tcomp) (eds-tcomp t dss build-refs))))) (define eds-tcomp - (λ (tc dss) + (λ (tc dss build-refs) (match tc [(tcomp-list->tensor lst) - (let-values (((ts dss^) + (let-values (((ts dss^ build-refs^) (for/fold ((ts '()) - (dss^ dss)) + (dss^ dss) + (build-refs^ build-refs)) ((l lst)) - (let-values (((t dss^^) (eds-expr l dss^))) - (values (cons t ts) dss^^))))) - (values (tcomp-list->tensor (reverse ts)) dss^))] - [(tcomp-build-tensor s f) - (values tc dss)] + (let-values (((t dss^^ build-refs^^) + (eds-expr l dss^ build-refs^))) + (values (cons t ts) dss^^ build-refs^^))))) + (values (tcomp-list->tensor (reverse ts)) dss^ build-refs^))] + [(tcomp-tref tp i) - (let-values (((t dss^) (eds-expr tp dss))) - (values (tcomp-tref t (tcomp-ds-ref (length dss^))) - (cons i dss^)))] - [(tcomp-trefs tp b) - (let-values (((t dss^) (eds-expr tp dss))) - (values (tcomp-trefs t (tcomp-ds-ref (length dss^))) - (cons (flat:list->tensor b) dss^)))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let-values (((t0 dss^) (eds-expr tp-t0 dss))) - (let-values (((t1 dss^^) (eds-expr tp-t1 dss^))) - (let-values (((z dss^^^) (eds-expr tp-z dss^^))) - (values (tcomp-ext2-∇ fᵈ r0 r1 shape-fn t0 t1 z out0 out1 i) - dss^^^))))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (let-values (((zp^ dss^^) (eds-expr zp dss^))) - (values (tcomp-ext1-∇ tp^ zp^ f m shape-fn) dss^^)))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (let-values (((t dss^) (eds-expr tp-t dss))) - (let-values (((u dss^^) (eds-expr tp-u dss^))) - (values (tcomp-ext2-ρ-scalar f t u) dss^^)))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (let-values (((t dss^) (eds-expr tp-t dss))) - (let-values (((u dss^^) (eds-expr tp-u dss^))) - (values (tcomp-ext2-ρ t u f m n shape-fn) dss^^)))] - [(tcomp-ext1-ρ-scalar f tp) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (values (tcomp-ext1-ρ-scalar f tp^) dss^))] - [(tcomp-ext1-ρ f m shape-fn tp) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (values (tcomp-ext1-ρ f m shape-fn tp^) dss^))] + (let-values (((t dss^ build-refs^) (eds-expr tp dss build-refs))) + (let-values (((idx dss^^) (insert-unless-exists i dss^ eqv?))) + (values (tcomp-tref t (tcomp-ds-ref idx)) dss^^ build-refs^)))] + [(tcomp-trefs tp i-list) + (let-values (((t dss^ build-refs^) (eds-expr tp dss build-refs))) + (let-values (((idx dss^^) + ;; Comparison by flat:tensor-equal? is okay because + ;; members of b are integers (not reals) and their + ;; equality is checked without a tolerance. + ;; TODO: Reconsider performance impact of flat:list->tensor. + ;; Maybe memoize it. + (insert-unless-exists (flat:list->tensor i-list) + dss^ + flat:tensor-equal?))) + (values (tcomp-trefs t (tcomp-ds-ref idx)) dss^^ build-refs^)))] + [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let-values (((t0 dss^ build-refs^) (eds-expr tp-t0 dss build-refs))) + (let-values (((t1 dss^^ build-refs^^) (eds-expr tp-t1 dss^ build-refs^))) + (let-values (((z dss^^^ build-refs^^^) (eds-expr tp-z dss^^ build-refs^^))) + (values (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn t0 t1 z + (length dss^^^) + (add1 (length dss^^^)) i) + (cons out1 (cons out0 dss^^^)) + build-refs^^^))))] + [(tcomp-ext1-∇ tp zp f signature m shape-fn) + (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) + (let-values (((zp^ dss^^ build-refs^^) (eds-expr zp dss^ build-refs^))) + (values (tcomp-ext1-∇ tp^ zp^ f signature m shape-fn) + dss^^ + build-refs^^)))] + [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) + (let-values (((t dss^ build-refs^) (eds-expr tp-t dss build-refs))) + (let-values (((u dss^^ build-refs^^) (eds-expr tp-u dss^ build-refs^))) + (values (tcomp-ext2-ρ-scalar f signature t u) dss^^ build-refs^^)))] + [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) + (let-values (((t dss^ build-refs^) (eds-expr tp-t dss build-refs))) + (let-values (((u dss^^ build-refs^^) (eds-expr tp-u dss^ build-refs^))) + (values (tcomp-ext2-ρ t u f signature m n shape-fn) dss^^ build-refs^^)))] + [(tcomp-ext1-ρ-scalar f signature tp) + (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) + (values (tcomp-ext1-ρ-scalar f signature tp^) dss^ build-refs^))] + [(tcomp-ext1-ρ f signature m shape-fn tp) + (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) + (values (tcomp-ext1-ρ f signature m shape-fn tp^) dss^ build-refs^))] [(tcomp-reshape s tp) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (values (tcomp-reshape s tp^) dss^))]))) + (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) + (values (tcomp-reshape s tp^) dss^ build-refs^))]))) (define hash-signatures? - (make-parameter #t)) + (make-parameter #f)) ;;TODO: Optimize sign by replacing it with this commented function #; (define sign @@ -121,22 +151,10 @@ (format "~a" (xxh32-digest xxh32-ctx))) (else (format "~a" s)))))) -#; -(define generate-signature - (λ (t) - (gs-expr t '()))) (define generate-signature (λ (t) (gs-expr t))) -#; -(define gs-expr - (λ (t position) - (match t - (s #:when (number? s) - (sign (format "s~a~a" s position))) - ((tpromise tc _) (gs-expr tc (cons 0 position))) - ((tcomp) (gs-tcomp t position))))) (define gs-expr (λ (t) (match t @@ -145,80 +163,6 @@ ((tpromise tc _) (gs-expr tc)) ((tcomp) (gs-tcomp t))))) -#; -(define gs-tcomp - (λ (tc position) - (match tc - [(tcomp-list->tensor lst) - (let ((list-sig - (for/fold ((sig "")) - ((l lst) - (i (in-naturals 0))) - (string-append sig (gs-expr l (cons i position)))))) - (sign (format "l>t~a~a" list-sig position)))] - [(tcomp-build-tensor s f) - (sign (format "bt~a~a~a" s f position))] - [(tcomp-tref tp i) - (sign - (format "tr~a~a~a" - (gs-expr tp (cons 0 position)) - (gs-expr i (cons 1 position)) - position))] - [(tcomp-trefs tp b) - (sign - (format "trs~a~a~a" - (gs-expr tp (cons 0 position)) - (gs-expr b (cons 1 position)) - position))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (sign - (format "e2∇~a~a_~a~a~a~a~a~a~a" - fᵈ r0 r1 shape-fn - (gs-expr tp-t0 (cons 0 position)) - (gs-expr tp-t1 (cons 1 position)) - (gs-expr tp-z (cons 2 position)) - i - position))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - (sign - (format "e1∇~a~a~a~a~a~a" - f m shape-fn - (gs-expr tp (cons 0 position)) - (gs-expr zp (cons 1 position)) - position))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (sign - (format "e2ρs~a~a~a~a" - f - (gs-expr tp-t (cons 0 position)) - (gs-expr tp-u (cons 1 position)) - position))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (sign - (format "e2ρ~a~a_~a~a~a~a~a" - f m n shape-fn - (gs-expr tp-t (cons 0 position)) - (gs-expr tp-u (cons 1 position)) - position))] - [(tcomp-ext1-ρ-scalar f tp) - (sign - (format "e1ρs~a~a~a" - f - (gs-expr tp (cons 0 position)) - position))] - [(tcomp-ext1-ρ f m shape-fn tp) - (sign - (format "e1ρ~a~a~a~a~a" - f m shape-fn - (gs-expr tp (cons 0 position)) - position))] - [(tcomp-reshape s tp) - (sign - (format "r~a~a~a" - s (gs-expr tp (cons 0 position)) position))] - [(tcomp-ds-ref index) - (sign - (format "dsr~a~a" index position))]))) (define gs-tcomp (λ (tc) (match tc @@ -229,47 +173,47 @@ (i (in-naturals 0))) (string-append sig (gs-expr l))))) (sign (format "l>t~a" list-sig)))] - [(tcomp-build-tensor s f) - (sign (format "bt~a~a" s f))] + [(tcomp-tref tp i) (sign (format "tr~a~a" (gs-expr tp ) (gs-expr i )))] [(tcomp-trefs tp b) (sign (format "trs~a~a" (gs-expr tp ) (gs-expr b )))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (sign - (format "e2∇~a~a_~a~a~a~a~a~a" - fᵈ r0 r1 shape-fn + (format "e2∇~a~a_~a~a~a~a~a~a~a~a" + signature r0 r1 shape-fn (gs-expr tp-t0 ) (gs-expr tp-t1 ) (gs-expr tp-z ) + out0 out1 i))] - [(tcomp-ext1-∇ tp zp f m shape-fn) + [(tcomp-ext1-∇ tp zp f signature m shape-fn) (sign (format "e1∇~a~a~a~a~a" - f m shape-fn + signature m shape-fn (gs-expr tp ) (gs-expr zp )))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) + [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) (sign (format "e2ρs~a~a~a" - f + signature (gs-expr tp-t ) (gs-expr tp-u )))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) (sign (format "e2ρ~a~a_~a~a~a~a" - f m n shape-fn + signature m n shape-fn (gs-expr tp-t ) (gs-expr tp-u )))] - [(tcomp-ext1-ρ-scalar f tp) + [(tcomp-ext1-ρ-scalar f signature tp) (sign - (format "e1ρs~a~a" f (gs-expr tp )))] - [(tcomp-ext1-ρ f m shape-fn tp) + (format "e1ρs~a~a" signature (gs-expr tp )))] + [(tcomp-ext1-ρ f signature m shape-fn tp) (sign (format "e1ρ~a~a~a~a" - f m shape-fn + signature m shape-fn (gs-expr tp )))] [(tcomp-reshape s tp) (sign @@ -284,79 +228,77 @@ ;; compiled output racket code. (define count-references (λ (t) - (count-references-expr t (hasheq)))) + (let-values (((counter uid) (count-references-expr t (hasheq) 0))) + counter))) (define count-references-expr - (λ (t counter) + (λ (t counter uid) (match t ((tpromise tc _) - (count-references-expr tc counter)) - ((tcomp) (count-references-tcomp t counter)) - (_ counter)))) + (count-references-expr tc counter uid)) + ((tcomp) (count-references-tcomp t counter uid)) + (_ (values counter uid))))) (define count-references-tcomp - (λ (tc counter) + (λ (tc counter uid) (match-let (((counter-data tc-binding-name tc-ref-count) (hash-ref counter tc (λ () (let-values (((st _) (struct-info tc))) (let-values (((tcomp-name _0 _1 _2 _3 _4 _5 _6) (struct-type-info st))) - (counter-data (gensym tcomp-name) 0))))))) + (counter-data (string->symbol + (format "~a_~a" tcomp-name uid)) + 0))))))) (let* ((new-count (add1 tc-ref-count)) (counter^ (hash-set counter tc (counter-data tc-binding-name - new-count)))) + new-count))) + (uid^ (add1 uid))) (cond ((> new-count 1) ;; No need to increase reference count of children if parent occurs ;; more than once. This helps avoid creating extra tcom-var later in ;; ecs if child tcomp occurs only once within the parent tcomp, but ;; the parent tcomp itself occurs more than once. - counter^) + (values counter^ uid^)) (else (match tc [(tcomp-list->tensor lst) (for/fold - ((counter^^ counter^)) + ((counter^^ counter^) + (uid^^ uid^)) ((l lst)) - (count-references-expr l counter^^))] - [(tcomp-build-tensor s f) counter^] + (count-references-expr l counter^^ uid^^))] [(tcomp-tref tp i) - (count-references-expr tp counter^)] + (count-references-expr tp counter^ uid^)] [(tcomp-trefs tp b) - (count-references-expr tp counter^)] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (count-references-expr - tp-z - (count-references-expr - tp-t1 - (count-references-expr tp-t0 counter^)))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - (count-references-expr - zp - (count-references-expr tp counter^))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (count-references-expr - tp-u - (count-references-expr tp-t counter^))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (count-references-expr - tp-u - (count-references-expr tp-t counter^))] - [(tcomp-ext1-ρ-scalar f tp) - (count-references-expr tp counter^)] - [(tcomp-ext1-ρ f m shape-fn tp) - (count-references-expr tp counter^)] + (count-references-expr tp counter^ uid^)] + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((counter-1 uid-1) (count-references-expr tp-t0 counter^ uid^)) + ((counter-2 uid-2) (count-references-expr tp-z counter-1 uid-1))) + (count-references-expr tp-t1 counter-2 uid-2))] + [(tcomp-ext1-∇ tp zp f sign m shape-fn) + (let-values (((counter-1 uid-1) (count-references-expr tp counter^ uid^))) + (count-references-expr zp counter-1 uid-1))] + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (let-values (((counter-1 uid-1) (count-references-expr tp-t counter^ uid^))) + (count-references-expr tp-u counter-1 uid-1))] + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (let-values (((counter-1 uid-1) (count-references-expr tp-t counter^ uid^))) + (count-references-expr tp-u counter-1 uid-1))] + [(tcomp-ext1-ρ-scalar f sign tp) + (count-references-expr tp counter^ uid^)] + [(tcomp-ext1-ρ f sign m shape-fn tp) + (count-references-expr tp counter^ uid^)] [(tcomp-reshape s tp) - (count-references-expr tp counter^)] - [(tcomp-ds-ref index) counter^] + (count-references-expr tp counter^ uid^)] + [(tcomp-ds-ref index) (values counter^ uid^)] ;;need these cases for testing compiler invariant [(tcomp-let lhs rhs body) - (count-references-expr - body - (count-references-expr rhs counter^))] - [(tcomp-var name) counter^]))))))) + (let-values (((counter-1 uid-1) (count-references-expr rhs counter^ uid^))) + (count-references-expr body counter-1 uid-1))] + [(tcomp-var name) (values counter^ uid^)]))))))) (define extract-common-subexpressions (λ (t counter) @@ -401,8 +343,6 @@ instrs-list-compiler (λ (instrs-list) (inj-ecs-tcomp (tcomp-list->tensor instrs-list) tc-counter-data))))] - [(tcomp-build-tensor s f) - (inj-ecs-tcomp tc tc-counter-data)] [(tcomp-tref tp i) (->ecs (ecs-expr tp counter) @@ -413,7 +353,7 @@ (ecs-expr tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-trefs instrs b) tc-counter-data)))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (->ecs (ecs-expr tp-t0 counter) (λ (t0-instrs) @@ -424,11 +364,11 @@ (ecs-expr tp-z counter) (λ (z-instrs) (inj-ecs-tcomp - (tcomp-ext2-∇ fᵈ r0 r1 shape-fn + (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn t0-instrs t1-instrs z-instrs out0 out1 i) tc-counter-data)))))))] - [(tcomp-ext1-∇ tp zp f m shape-fn) + [(tcomp-ext1-∇ tp zp f sign m shape-fn) (->ecs (ecs-expr tp counter) (λ (t-instrs) @@ -436,9 +376,9 @@ (ecs-expr zp counter) (λ (z-instrs) (inj-ecs-tcomp - (tcomp-ext1-∇ t-instrs z-instrs f m shape-fn) + (tcomp-ext1-∇ t-instrs z-instrs f sign m shape-fn) tc-counter-data)))))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) (->ecs (ecs-expr tp-t counter) (λ (t-instrs) @@ -446,9 +386,9 @@ (ecs-expr tp-u counter) (λ (u-instrs) (inj-ecs-tcomp - (tcomp-ext2-ρ-scalar f t-instrs u-instrs) + (tcomp-ext2-ρ-scalar f sign t-instrs u-instrs) tc-counter-data)))))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) (->ecs (ecs-expr tp-t counter) (λ (t-instrs) @@ -456,18 +396,18 @@ (ecs-expr tp-u counter) (λ (u-instrs) (inj-ecs-tcomp - (tcomp-ext2-ρ t-instrs u-instrs f m n shape-fn) + (tcomp-ext2-ρ t-instrs u-instrs f sign m n shape-fn) tc-counter-data)))))] - [(tcomp-ext1-ρ-scalar f tp) + [(tcomp-ext1-ρ-scalar f sign tp) (->ecs (ecs-expr tp counter) (λ (instrs) - (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f instrs) tc-counter-data)))] - [(tcomp-ext1-ρ f m shape-fn tp) + (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f sign instrs) tc-counter-data)))] + [(tcomp-ext1-ρ f sign m shape-fn tp) (->ecs (ecs-expr tp counter) (λ (instrs) - (inj-ecs-tcomp (tcomp-ext1-ρ f m shape-fn instrs) tc-counter-data)))] + (inj-ecs-tcomp (tcomp-ext1-ρ f sign m shape-fn instrs) tc-counter-data)))] [(tcomp-reshape s tp) (->ecs (ecs-expr tp counter) @@ -519,79 +459,78 @@ (`(,_ . ,rest-env) (exists-in-env? ft rest-env))))) (define generate-racket - (λ (t) - (gr-expr t))) + (λ (t build-refs) + (gr-expr t build-refs))) (define gr-expr - (λ (t) + (λ (t build-refs) (match t - [(tpromise tc _) (gr-expr tc)] + [(tpromise tc _) (gr-expr tc build-refs)] [v #:when (number? v) v] - [(tcomp) (gr-tcomp t)]))) + [(tcomp) (gr-tcomp t build-refs)]))) (define gr-tcomp - (λ (tc) + (λ (tc build-refs) (match tc [(tcomp-list->tensor lst) - (let ((instrs-list (map gr-expr lst))) + (let ((instrs-list (map (λ (t) (gr-expr t build-refs)) lst))) `(flat:list->tensor (list ,@instrs-list)))] - [(tcomp-build-tensor s f) - (flat:build-tensor s f)] [(tcomp-tref tp i) - (let ((instrs (gr-expr tp)) - (i-instrs (gr-expr i))) + (let ((instrs (gr-expr tp build-refs)) + (i-instrs (gr-expr i build-refs))) `(flat:tref ,instrs ,i-instrs))] [(tcomp-trefs tp b) - (let ((instrs (gr-expr tp)) - (b-instrs (gr-expr b))) + (let ((instrs (gr-expr tp build-refs)) + (b-instrs (gr-expr b build-refs))) `(rt:trefs ,instrs ,b-instrs))] - [(tcomp-ext2-∇ fᵈ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let ((t0-instrs (gr-expr tp-t0)) - (t1-instrs (gr-expr tp-t1)) - (z-instrs (gr-expr tp-z))) - `(let* ([b (if (zero? ,i) ,out0 ,out1)] - [v (ext2-∇-result-res b)]) - (cond - ((eqv? v 'uncalculated) - (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn - ,t0-instrs ,t1-instrs - ,z-instrs ,out0 ,out1) - (ext2-∇-result-res b)) - (else v))))] - [(tcomp-ext1-∇ tp zp f m shape-fn) - (let ((t-instrs (gr-expr tp)) - (z-instrs (gr-expr zp))) + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (gr-expr tp-t0 build-refs)) + (t1-instrs (gr-expr tp-t1 build-refs)) + (z-instrs (gr-expr tp-z build-refs))) + (let ((b (if (zero? i) out0 out1))) + `(let* ([b ,b] + [v (data-segment-ref b)]) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out0 ,out1) + (data-segment-ref b)) + (else v)))))] + [(tcomp-ext1-∇ tp zp f sign m shape-fn) + (let ((t-instrs (gr-expr tp build-refs)) + (z-instrs (gr-expr zp build-refs))) `(scalarize (flat-ext1-∇ ,f ,m ,shape-fn (ensure-flat ,t-instrs) (ensure-flat ,z-instrs))))] - [(tcomp-ext2-ρ-scalar f tp-t tp-u) - (let ((t-instrs (gr-expr tp-t)) - (u-instrs (gr-expr tp-u))) + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (let ((t-instrs (gr-expr tp-t build-refs)) + (u-instrs (gr-expr tp-u build-refs))) `(,f ,t-instrs ,u-instrs))] - [(tcomp-ext2-ρ tp-t tp-u f m n shape-fn) - (let ((t-instrs (gr-expr tp-t)) - (u-instrs (gr-expr tp-u))) + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (let ((t-instrs (gr-expr tp-t build-refs)) + (u-instrs (gr-expr tp-u build-refs))) `(scalarize (flat-ext2-ρ ,f ,m ,n ,shape-fn (ensure-flat ,t-instrs) (ensure-flat ,u-instrs))))] - [(tcomp-ext1-ρ-scalar f tp) - (let ((instrs (gr-expr tp))) + [(tcomp-ext1-ρ-scalar f sign tp) + (let ((instrs (gr-expr tp build-refs))) `(,f ,instrs))] - [(tcomp-ext1-ρ f m shape-fn tp) - (let ((instrs (gr-expr tp))) + [(tcomp-ext1-ρ f sign m shape-fn tp) + (let ((instrs (gr-expr tp build-refs))) `(scalarize (flat-ext1-ρ ,f ,m ,shape-fn (ensure-flat ,instrs))))] [(tcomp-reshape s tp) - (let ((instrs (gr-expr tp))) + (let ((instrs (gr-expr tp build-refs))) `(flat ',s (flat-store ,instrs) (flat-offset ,instrs)))] [(tcomp-let lhs rhs body) - (let ((rhs-instrs (gr-expr rhs)) - (body-instrs (gr-expr body))) + (let ((rhs-instrs (gr-expr rhs build-refs)) + (body-instrs (gr-expr body build-refs))) `(let ((,lhs ,rhs-instrs)) ,body-instrs))] [(tcomp-var name) name] @@ -610,42 +549,42 @@ (displayln "--------------") (pretty-print value) (displayln "")))) +(define cache + (make-parameter (make-hash))) (define compile-tensor - (let ((cache (make-hash))) - (λ (t) - (display-compiler-trace 'Source-Tensor t) - (let-values (((eds-instrs ds) (extract-data-segment t))) - (display-compiler-trace 'Extract-Data-Segment-data ds) - (display-compiler-trace 'Extract-Data-Segment-instructions eds-instrs) - (let ((signature (generate-signature eds-instrs))) - (display-compiler-trace 'Generate-Signature signature) - (cond - ;; TODO: Uncomment this to reenable caching - (#f #;(hash-has-key? cache signature) - (let ((compiled (hash-ref cache signature))) - (display-compiler-trace 'Cache-Hit compiled) - (values compiled ds))) - (else - (let ((counter (count-references eds-instrs))) - (display-compiler-trace 'Count-References counter) - (let ((extracted (extract-common-subexpressions eds-instrs counter))) - (display-compiler-trace 'Extract-Common-Subexpressions extracted) - (let ((rkt (generate-racket extracted))) - (display-compiler-trace 'Generate-Racket rkt) - (hash-set! cache signature rkt) - (values rkt ds))))))))))) + (λ (t) + (display-compiler-trace 'Source-Tensor t) + (let-values (((eds-instrs ds build-refs) (extract-data-segment t))) + (display-compiler-trace 'Extract-Data-Segment-data ds) + (display-compiler-trace 'Extract-Data-Segment-instructions eds-instrs) + (let ((signature (generate-signature eds-instrs))) + (display-compiler-trace 'Generate-Signature signature) + (cond + ((hash-has-key? (cache) signature) + (let ((compiled (hash-ref (cache) signature))) + (display-compiler-trace 'Cache-Hit signature) + (values compiled ds))) + (else + (let ((counter (count-references eds-instrs))) + (display-compiler-trace 'Count-References counter) + (let ((extracted (extract-common-subexpressions eds-instrs counter))) + (display-compiler-trace 'Extract-Common-Subexpressions extracted) + (let ((rkt (generate-racket extracted build-refs))) + (display-compiler-trace 'Generate-Racket rkt) + (hash-set! (cache) signature rkt) + (values rkt ds)))))))))) ;;TODO: update this for new compiler passes (define compile-tensor/checks (λ (t) - (let-values (((eds-instrs ds) (extract-data-segment t))) + (let-values (((eds-instrs ds build-refs) (extract-data-segment t))) (flat:check-tensor-equal? (interp-tensor t) (interp-tensor eds-instrs)) (let ((counter (count-references t))) (let ((extracted (extract-common-subexpressions t counter))) (flat:check-tensor-equal? (interp-tensor t) (interp-tensor extracted)) (for/list ((cd (hash-values (count-references extracted)))) (check-equal? (counter-data-ref-count cd) 1)) - (let-values (((rkt env) (generate-racket extracted))) + (let-values (((rkt env) (generate-racket extracted build-refs))) (flat:check-tensor-equal? (interp-tensor extracted) (interp-racket rkt env)) (values rkt env))))))) @@ -654,7 +593,9 @@ (λ (t) (let-values (((instrs env) (compile-tensor t))) - (make-instrs instrs env)))) + `(parameterize ((data-segment ,env)) + ,instrs)))) (include "test/test-c3-compiler.rkt") -(provide get-compiled compile-tensor compile-tensor/checks print-compiler?) +(provide get-compiled compile-tensor compile-tensor/checks print-compiler? + (rename-out (cache compiler-cache))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt index 6346084..1724f17 100644 --- a/lazy/tensors/test/test-1-reflect.rkt +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -80,9 +80,9 @@ (flat:check-tensor-equal? (↓ test-tcomp-partial-eval) (↓ (tensor 1 2 3))) + (define test-built-tensor (get-test-program 'built-tensor)) (check-compiler-invariants test-built-tensor) (check-equal? (tpromise-shape test-built-tensor) test-build-shape) - (check-true (tcomp? (tpromise-tensor test-built-tensor))) (flat:check-tensor-equal? (↓ test-built-tensor) (↓ (tensor (tensor 0 1 2) (tensor 3 4 5) @@ -277,4 +277,9 @@ (check-pred (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) - (vector->list (flat:flat-store (↓ test-build-random))))) + (vector->list (flat:flat-store (↓ test-build-random))) + "Side-effect of generating random tensor must only be run once") + + (flat:check-tensor-equal? (↓ (get-test-program 'multi-built-tensor)) + (get-test-eval-res 'multi-built-tensor)) +) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index e8e4c94..f83a20e 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -4,8 +4,8 @@ (require "0-lazy.rkt") (define-check (check-signatures-equal? t1 t2) - (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) - ((eds-instrs-2 ds2) (extract-data-segment t2))) + (let-values (((eds-instrs-1 ds1 build-refs1) (extract-data-segment t1)) + ((eds-instrs-2 ds2 build-refs2) (extract-data-segment t2))) (let ((sig1 (generate-signature eds-instrs-1)) (sig2 (generate-signature eds-instrs-2))) (with-check-info @@ -13,14 +13,16 @@ ('extracted-instrs-2 eds-instrs-2) ('data-segment-1 ds1) ('data-segment-2 ds2) + ('refs-for-build-tensor-nodes-1 build-refs1) + ('refs-for-build-tensor-nodes-2 build-refs2) ('signature-1 sig1) ('signature-2 sig2)) (unless (equal? sig1 sig2) (fail-check "signature mismatch")))))) (define-check (check-signatures-not-equal? t1 t2) - (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) - ((eds-instrs-2 ds2) (extract-data-segment t2))) + (let-values (((eds-instrs-1 ds1 build-refs1) (extract-data-segment t1)) + ((eds-instrs-2 ds2 build-refs2) (extract-data-segment t2))) (let ((sig1 (generate-signature eds-instrs-1)) (sig2 (generate-signature eds-instrs-2))) (with-check-info @@ -28,6 +30,8 @@ ('extracted-instrs-2 eds-instrs-2) ('data-segment-1 ds1) ('data-segment-2 ds2) + ('refs-for-build-tensor-nodes-1 build-refs1) + ('refs-for-build-tensor-nodes-2 build-refs2) ('signature-1 sig1) ('signature-2 sig2)) (when (equal? sig1 sig2) @@ -49,10 +53,10 @@ (check-signatures-equal? (mean (get-test-program 'tensor-r2-0)) (mean (tensor (tensor 12 23 44) (tensor 23 46 57)))) - (check-signatures-not-equal? (mean (get-test-program 'tensor-r2-0)) - (mean (tensor (tensor 12 23 44) - (tensor 23 46 57) - (tensor 67 32 58)))) + (check-signatures-equal? (mean (get-test-program 'tensor-r2-0)) + (mean (tensor (tensor 12 23 44) + (tensor 23 46 57) + (tensor 67 32 58)))) (check-signatures-not-equal? (mean (get-test-program 'tensor-r2-0)) (mean (reshape '(2 3) (tensor 1 2 3 4 5 6)))) (check-signatures-equal? variance-v (variance v^)) @@ -66,8 +70,25 @@ (let ((a 2) (b 3)) - ;;TODO: Fix these test cases (let*-values (((da- db-) (d- a b 1.0)) ((da+ db+) (d+ a b 1.0))) (check-signatures-not-equal? da- da+) - (check-signatures-not-equal? db- db+)))) + (check-signatures-not-equal? db- db+))) + + (let-values (((rkt ds) (compile-tensor (get-test-program 'extract-ds-once-tref)))) + (check-pred + (λ (ds) + (eqv? (vector-length ds) 2)) + ds + (string-append "Tensors and tref indices occurring multiple times in" + " source AST but referring to the same tensor AST node must" + " be added to the data segment only once."))) + (let-values (((rkt ds) (compile-tensor (get-test-program 'extract-ds-once-trefs)))) + (check-pred + (λ (ds) + (eqv? (vector-length ds) 2)) + ds + (string-append "Tensors and trefs index lists occurring multiple times in" + " source AST but pointing to the same tensor AST node must" + " be added to the data segment only once."))) + ) From e085e98054557e5de45d470113bb20e00722243b Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Thu, 16 Nov 2023 16:56:24 -0500 Subject: [PATCH 66/83] [add-lazy]Remove uses of build-refs parameter in compiler --- lazy/tensors/c3-compiler.rkt | 172 ++++++++++++------------- lazy/tensors/test/test-c3-compiler.rkt | 12 +- 2 files changed, 87 insertions(+), 97 deletions(-) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 4cfa685..38ae3a6 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -28,14 +28,11 @@ ;; * the symbol 'uncalculated as an initial placeholder for the output of ;; tcomp-ext2-∇ which will be later replaced by the flat tensor output ;; -;; TODO: Remove all uses of build-refs because we longer have tcomp-build-tensor nodes (define extract-data-segment (λ (t) - (let-values (((t^ data-segment-stack build-refs) (eds-expr t '() (hasheq)))) + (let-values (((t^ data-segment-stack) (eds-expr t '()))) ;; convert data segment stack to data segment array - (values t^ - (list->vector (reverse data-segment-stack)) - build-refs)))) + (values t^ (list->vector (reverse data-segment-stack)))))) ;; Checks if a member equivalent to v exists in dss using equiv? and based on ;; that returns the dss index where v was inserted and the new dss with @@ -50,41 +47,40 @@ (else (values (length dss) (cons v dss)))))) (define eds-expr - (λ (t dss build-refs) + (λ (t dss) (match t (s #:when (number? s) - (values s dss build-refs)) + (values s dss)) (ft #:when (flat? ft) (let-values (((idx dss^) (insert-unless-exists ft dss eq?))) - (values (tcomp-ds-ref idx) dss^ build-refs))) + (values (tcomp-ds-ref idx) dss^))) ((tpromise tc s) - (let-values (((tc^ dss^ build-refs^) (eds-expr tc dss build-refs))) + (let-values (((tc^ dss^) (eds-expr tc dss))) (cond - ((number? tc^) (values tc^ dss^ build-refs^)) - (else (values (tpromise tc^ s) dss^ build-refs^))))) - ((tcomp) (eds-tcomp t dss build-refs))))) + ((number? tc^) (values tc^ dss^)) + (else (values (tpromise tc^ s) dss^))))) + ((tcomp) (eds-tcomp t dss))))) (define eds-tcomp - (λ (tc dss build-refs) + (λ (tc dss) (match tc [(tcomp-list->tensor lst) - (let-values (((ts dss^ build-refs^) + (let-values (((ts dss^) (for/fold ((ts '()) - (dss^ dss) - (build-refs^ build-refs)) + (dss^ dss)) ((l lst)) - (let-values (((t dss^^ build-refs^^) - (eds-expr l dss^ build-refs^))) - (values (cons t ts) dss^^ build-refs^^))))) - (values (tcomp-list->tensor (reverse ts)) dss^ build-refs^))] + (let-values (((t dss^^) + (eds-expr l dss^))) + (values (cons t ts) dss^^))))) + (values (tcomp-list->tensor (reverse ts)) dss^))] [(tcomp-tref tp i) - (let-values (((t dss^ build-refs^) (eds-expr tp dss build-refs))) + (let-values (((t dss^) (eds-expr tp dss))) (let-values (((idx dss^^) (insert-unless-exists i dss^ eqv?))) - (values (tcomp-tref t (tcomp-ds-ref idx)) dss^^ build-refs^)))] + (values (tcomp-tref t (tcomp-ds-ref idx)) dss^^)))] [(tcomp-trefs tp i-list) - (let-values (((t dss^ build-refs^) (eds-expr tp dss build-refs))) + (let-values (((t dss^) (eds-expr tp dss))) (let-values (((idx dss^^) ;; Comparison by flat:tensor-equal? is okay because ;; members of b are integers (not reals) and their @@ -94,39 +90,37 @@ (insert-unless-exists (flat:list->tensor i-list) dss^ flat:tensor-equal?))) - (values (tcomp-trefs t (tcomp-ds-ref idx)) dss^^ build-refs^)))] + (values (tcomp-trefs t (tcomp-ds-ref idx)) dss^^)))] [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let-values (((t0 dss^ build-refs^) (eds-expr tp-t0 dss build-refs))) - (let-values (((t1 dss^^ build-refs^^) (eds-expr tp-t1 dss^ build-refs^))) - (let-values (((z dss^^^ build-refs^^^) (eds-expr tp-z dss^^ build-refs^^))) + (let-values (((t0 dss^) (eds-expr tp-t0 dss))) + (let-values (((t1 dss^^) (eds-expr tp-t1 dss^))) + (let-values (((z dss^^^) (eds-expr tp-z dss^^))) (values (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn t0 t1 z (length dss^^^) (add1 (length dss^^^)) i) - (cons out1 (cons out0 dss^^^)) - build-refs^^^))))] + (cons out1 (cons out0 dss^^^))))))] [(tcomp-ext1-∇ tp zp f signature m shape-fn) - (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) - (let-values (((zp^ dss^^ build-refs^^) (eds-expr zp dss^ build-refs^))) + (let-values (((tp^ dss^) (eds-expr tp dss))) + (let-values (((zp^ dss^^) (eds-expr zp dss^))) (values (tcomp-ext1-∇ tp^ zp^ f signature m shape-fn) - dss^^ - build-refs^^)))] + dss^^)))] [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) - (let-values (((t dss^ build-refs^) (eds-expr tp-t dss build-refs))) - (let-values (((u dss^^ build-refs^^) (eds-expr tp-u dss^ build-refs^))) - (values (tcomp-ext2-ρ-scalar f signature t u) dss^^ build-refs^^)))] + (let-values (((t dss^) (eds-expr tp-t dss))) + (let-values (((u dss^^) (eds-expr tp-u dss^))) + (values (tcomp-ext2-ρ-scalar f signature t u) dss^^)))] [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) - (let-values (((t dss^ build-refs^) (eds-expr tp-t dss build-refs))) - (let-values (((u dss^^ build-refs^^) (eds-expr tp-u dss^ build-refs^))) - (values (tcomp-ext2-ρ t u f signature m n shape-fn) dss^^ build-refs^^)))] + (let-values (((t dss^) (eds-expr tp-t dss))) + (let-values (((u dss^^) (eds-expr tp-u dss^))) + (values (tcomp-ext2-ρ t u f signature m n shape-fn) dss^^)))] [(tcomp-ext1-ρ-scalar f signature tp) - (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) - (values (tcomp-ext1-ρ-scalar f signature tp^) dss^ build-refs^))] + (let-values (((tp^ dss^) (eds-expr tp dss))) + (values (tcomp-ext1-ρ-scalar f signature tp^) dss^))] [(tcomp-ext1-ρ f signature m shape-fn tp) - (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) - (values (tcomp-ext1-ρ f signature m shape-fn tp^) dss^ build-refs^))] + (let-values (((tp^ dss^) (eds-expr tp dss))) + (values (tcomp-ext1-ρ f signature m shape-fn tp^) dss^))] [(tcomp-reshape s tp) - (let-values (((tp^ dss^ build-refs^) (eds-expr tp dss build-refs))) - (values (tcomp-reshape s tp^) dss^ build-refs^))]))) + (let-values (((tp^ dss^) (eds-expr tp dss))) + (values (tcomp-reshape s tp^) dss^))]))) (define hash-signatures? (make-parameter #f)) @@ -176,52 +170,52 @@ [(tcomp-tref tp i) (sign - (format "tr~a~a" (gs-expr tp ) (gs-expr i )))] + (format "tr~a~a" (gs-expr tp) (gs-expr i)))] [(tcomp-trefs tp b) (sign - (format "trs~a~a" (gs-expr tp ) (gs-expr b )))] + (format "trs~a~a" (gs-expr tp) (gs-expr b)))] [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (sign (format "e2∇~a~a_~a~a~a~a~a~a~a~a" signature r0 r1 shape-fn - (gs-expr tp-t0 ) - (gs-expr tp-t1 ) - (gs-expr tp-z ) + (gs-expr tp-t0) + (gs-expr tp-t1) + (gs-expr tp-z) out0 out1 i))] [(tcomp-ext1-∇ tp zp f signature m shape-fn) (sign (format "e1∇~a~a~a~a~a" signature m shape-fn - (gs-expr tp ) - (gs-expr zp )))] + (gs-expr tp) + (gs-expr zp)))] [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) (sign (format "e2ρs~a~a~a" signature - (gs-expr tp-t ) - (gs-expr tp-u )))] + (gs-expr tp-t) + (gs-expr tp-u)))] [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) (sign (format "e2ρ~a~a_~a~a~a~a" signature m n shape-fn - (gs-expr tp-t ) - (gs-expr tp-u )))] + (gs-expr tp-t) + (gs-expr tp-u)))] [(tcomp-ext1-ρ-scalar f signature tp) (sign - (format "e1ρs~a~a" signature (gs-expr tp )))] + (format "e1ρs~a~a" signature (gs-expr tp)))] [(tcomp-ext1-ρ f signature m shape-fn tp) (sign (format "e1ρ~a~a~a~a" signature m shape-fn - (gs-expr tp )))] + (gs-expr tp)))] [(tcomp-reshape s tp) (sign (format "r~a~a" - s (gs-expr tp ) ))] + s (gs-expr tp)))] [(tcomp-ds-ref index) (sign - (format "dsr~a" index ))]))) + (format "dsr~a" index))]))) ;; Count references so that the tcomp AST nodes that refer to the same memory ;; location i.e. common AST nodes get extracted by let-binding them in the @@ -459,34 +453,34 @@ (`(,_ . ,rest-env) (exists-in-env? ft rest-env))))) (define generate-racket - (λ (t build-refs) - (gr-expr t build-refs))) + (λ (t) + (gr-expr t))) (define gr-expr - (λ (t build-refs) + (λ (t) (match t - [(tpromise tc _) (gr-expr tc build-refs)] + [(tpromise tc _) (gr-expr tc)] [v #:when (number? v) v] - [(tcomp) (gr-tcomp t build-refs)]))) + [(tcomp) (gr-tcomp t)]))) (define gr-tcomp - (λ (tc build-refs) + (λ (tc) (match tc [(tcomp-list->tensor lst) - (let ((instrs-list (map (λ (t) (gr-expr t build-refs)) lst))) + (let ((instrs-list (map (λ (t) (gr-expr t)) lst))) `(flat:list->tensor (list ,@instrs-list)))] [(tcomp-tref tp i) - (let ((instrs (gr-expr tp build-refs)) - (i-instrs (gr-expr i build-refs))) + (let ((instrs (gr-expr tp)) + (i-instrs (gr-expr i))) `(flat:tref ,instrs ,i-instrs))] [(tcomp-trefs tp b) - (let ((instrs (gr-expr tp build-refs)) - (b-instrs (gr-expr b build-refs))) + (let ((instrs (gr-expr tp)) + (b-instrs (gr-expr b))) `(rt:trefs ,instrs ,b-instrs))] [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let ((t0-instrs (gr-expr tp-t0 build-refs)) - (t1-instrs (gr-expr tp-t1 build-refs)) - (z-instrs (gr-expr tp-z build-refs))) + (let ((t0-instrs (gr-expr tp-t0)) + (t1-instrs (gr-expr tp-t1)) + (z-instrs (gr-expr tp-z))) (let ((b (if (zero? i) out0 out1))) `(let* ([b ,b] [v (data-segment-ref b)]) @@ -498,39 +492,39 @@ (data-segment-ref b)) (else v)))))] [(tcomp-ext1-∇ tp zp f sign m shape-fn) - (let ((t-instrs (gr-expr tp build-refs)) - (z-instrs (gr-expr zp build-refs))) + (let ((t-instrs (gr-expr tp)) + (z-instrs (gr-expr zp))) `(scalarize (flat-ext1-∇ ,f ,m ,shape-fn (ensure-flat ,t-instrs) (ensure-flat ,z-instrs))))] [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) - (let ((t-instrs (gr-expr tp-t build-refs)) - (u-instrs (gr-expr tp-u build-refs))) + (let ((t-instrs (gr-expr tp-t)) + (u-instrs (gr-expr tp-u))) `(,f ,t-instrs ,u-instrs))] [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) - (let ((t-instrs (gr-expr tp-t build-refs)) - (u-instrs (gr-expr tp-u build-refs))) + (let ((t-instrs (gr-expr tp-t)) + (u-instrs (gr-expr tp-u))) `(scalarize (flat-ext2-ρ ,f ,m ,n ,shape-fn (ensure-flat ,t-instrs) (ensure-flat ,u-instrs))))] [(tcomp-ext1-ρ-scalar f sign tp) - (let ((instrs (gr-expr tp build-refs))) + (let ((instrs (gr-expr tp))) `(,f ,instrs))] [(tcomp-ext1-ρ f sign m shape-fn tp) - (let ((instrs (gr-expr tp build-refs))) + (let ((instrs (gr-expr tp))) `(scalarize (flat-ext1-ρ ,f ,m ,shape-fn (ensure-flat ,instrs))))] [(tcomp-reshape s tp) - (let ((instrs (gr-expr tp build-refs))) + (let ((instrs (gr-expr tp))) `(flat ',s (flat-store ,instrs) (flat-offset ,instrs)))] [(tcomp-let lhs rhs body) - (let ((rhs-instrs (gr-expr rhs build-refs)) - (body-instrs (gr-expr body build-refs))) + (let ((rhs-instrs (gr-expr rhs)) + (body-instrs (gr-expr body))) `(let ((,lhs ,rhs-instrs)) ,body-instrs))] [(tcomp-var name) name] @@ -554,7 +548,7 @@ (define compile-tensor (λ (t) (display-compiler-trace 'Source-Tensor t) - (let-values (((eds-instrs ds build-refs) (extract-data-segment t))) + (let-values (((eds-instrs ds) (extract-data-segment t))) (display-compiler-trace 'Extract-Data-Segment-data ds) (display-compiler-trace 'Extract-Data-Segment-instructions eds-instrs) (let ((signature (generate-signature eds-instrs))) @@ -569,7 +563,7 @@ (display-compiler-trace 'Count-References counter) (let ((extracted (extract-common-subexpressions eds-instrs counter))) (display-compiler-trace 'Extract-Common-Subexpressions extracted) - (let ((rkt (generate-racket extracted build-refs))) + (let ((rkt (generate-racket extracted))) (display-compiler-trace 'Generate-Racket rkt) (hash-set! (cache) signature rkt) (values rkt ds)))))))))) @@ -577,14 +571,14 @@ ;;TODO: update this for new compiler passes (define compile-tensor/checks (λ (t) - (let-values (((eds-instrs ds build-refs) (extract-data-segment t))) + (let-values (((eds-instrs ds) (extract-data-segment t))) (flat:check-tensor-equal? (interp-tensor t) (interp-tensor eds-instrs)) (let ((counter (count-references t))) (let ((extracted (extract-common-subexpressions t counter))) (flat:check-tensor-equal? (interp-tensor t) (interp-tensor extracted)) (for/list ((cd (hash-values (count-references extracted)))) (check-equal? (counter-data-ref-count cd) 1)) - (let-values (((rkt env) (generate-racket extracted build-refs))) + (let-values (((rkt env) (generate-racket extracted))) (flat:check-tensor-equal? (interp-tensor extracted) (interp-racket rkt env)) (values rkt env))))))) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index f83a20e..33093ea 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -4,8 +4,8 @@ (require "0-lazy.rkt") (define-check (check-signatures-equal? t1 t2) - (let-values (((eds-instrs-1 ds1 build-refs1) (extract-data-segment t1)) - ((eds-instrs-2 ds2 build-refs2) (extract-data-segment t2))) + (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) + ((eds-instrs-2 ds2) (extract-data-segment t2))) (let ((sig1 (generate-signature eds-instrs-1)) (sig2 (generate-signature eds-instrs-2))) (with-check-info @@ -13,16 +13,14 @@ ('extracted-instrs-2 eds-instrs-2) ('data-segment-1 ds1) ('data-segment-2 ds2) - ('refs-for-build-tensor-nodes-1 build-refs1) - ('refs-for-build-tensor-nodes-2 build-refs2) ('signature-1 sig1) ('signature-2 sig2)) (unless (equal? sig1 sig2) (fail-check "signature mismatch")))))) (define-check (check-signatures-not-equal? t1 t2) - (let-values (((eds-instrs-1 ds1 build-refs1) (extract-data-segment t1)) - ((eds-instrs-2 ds2 build-refs2) (extract-data-segment t2))) + (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) + ((eds-instrs-2 ds2) (extract-data-segment t2))) (let ((sig1 (generate-signature eds-instrs-1)) (sig2 (generate-signature eds-instrs-2))) (with-check-info @@ -30,8 +28,6 @@ ('extracted-instrs-2 eds-instrs-2) ('data-segment-1 ds1) ('data-segment-2 ds2) - ('refs-for-build-tensor-nodes-1 build-refs1) - ('refs-for-build-tensor-nodes-2 build-refs2) ('signature-1 sig1) ('signature-2 sig2)) (when (equal? sig1 sig2) From 61421a6083ee73d206a7728fe76fca6cf4fa19c7 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 17 Nov 2023 12:29:35 -0500 Subject: [PATCH 67/83] [add-lazy]Optimize signature generation using bytes-join --- lazy/tensors/c3-compiler.rkt | 83 +++++++++++------------------------- 1 file changed, 25 insertions(+), 58 deletions(-) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 38ae3a6..198e3ed 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -124,26 +124,22 @@ (define hash-signatures? (make-parameter #f)) -;;TODO: Optimize sign by replacing it with this commented function -#; + (define sign (let ((xxh32-ctx (make-xxh32))) (λ ss (cond ((hash-signatures?) (xxh32-reset! xxh32-ctx 0) - (xxh32-update! xxh32-ctx (apply bytes-append ss)) + (xxh32-update! xxh32-ctx (apply bytes-join ss #"_")) (xxh32-digest xxh32-ctx)) (else (format "~a" ss)))))) -(define sign - (let ((xxh32-ctx (make-xxh32))) - (λ (s) - (cond - ((hash-signatures?) - (xxh32-reset! xxh32-ctx 0) - (xxh32-update! xxh32-ctx (string->bytes/utf-8 s)) - (format "~a" (xxh32-digest xxh32-ctx))) - (else (format "~a" s)))))) + +(define number->bytes + (λ (n) + (string->bytes/utf-8 (number->string n)))) + +(define string->bytes string->bytes/utf-8) (define generate-signature (λ (t) @@ -153,7 +149,7 @@ (λ (t) (match t (s #:when (number? s) - (sign (format "s~a" s))) + (sign #"s~a" (number->bytes s))) ((tpromise tc _) (gs-expr tc)) ((tcomp) (gs-tcomp t))))) @@ -161,61 +157,32 @@ (λ (tc) (match tc [(tcomp-list->tensor lst) - (let ((list-sig - (for/fold ((sig "")) - ((l lst) - (i (in-naturals 0))) - (string-append sig (gs-expr l))))) - (sign (format "l>t~a" list-sig)))] - + (apply sign #"l>t" (map gs-expr lst))] [(tcomp-tref tp i) - (sign - (format "tr~a~a" (gs-expr tp) (gs-expr i)))] + (sign "tr" (gs-expr tp) (gs-expr i))] [(tcomp-trefs tp b) - (sign - (format "trs~a~a" (gs-expr tp) (gs-expr b)))] + (sign #"trs" (gs-expr tp) (gs-expr b))] [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (sign - (format "e2∇~a~a_~a~a~a~a~a~a~a~a" - signature r0 r1 shape-fn - (gs-expr tp-t0) - (gs-expr tp-t1) - (gs-expr tp-z) - out0 out1 - i))] + (sign #"e2n" (string->bytes signature) + (number->bytes r0) (number->bytes r1) + (gs-expr tp-t0) (gs-expr tp-t1) (gs-expr tp-z) + (number->bytes out0) (number->bytes out1) (number->bytes i))] [(tcomp-ext1-∇ tp zp f signature m shape-fn) - (sign - (format "e1∇~a~a~a~a~a" - signature m shape-fn - (gs-expr tp) - (gs-expr zp)))] + (sign #"e1n" (string->bytes signature) (number->bytes m) + (gs-expr tp) (gs-expr zp))] [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) - (sign - (format "e2ρs~a~a~a" - signature - (gs-expr tp-t) - (gs-expr tp-u)))] + (sign #"e2rs" (string->bytes signature) (gs-expr tp-t) (gs-expr tp-u))] [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) - (sign - (format "e2ρ~a~a_~a~a~a~a" - signature m n shape-fn - (gs-expr tp-t) - (gs-expr tp-u)))] + (sign #"e2r" (string->bytes signature) (number->bytes m) (number->bytes m) + (gs-expr tp-t) (gs-expr tp-u))] [(tcomp-ext1-ρ-scalar f signature tp) - (sign - (format "e1ρs~a~a" signature (gs-expr tp)))] + (sign #"e1rs" (string->bytes signature) (gs-expr tp))] [(tcomp-ext1-ρ f signature m shape-fn tp) - (sign - (format "e1ρ~a~a~a~a" - signature m shape-fn - (gs-expr tp)))] + (sign #"e1r" (string->bytes signature) (number->bytes m) (gs-expr tp))] [(tcomp-reshape s tp) - (sign - (format "r~a~a" - s (gs-expr tp)))] + (apply sign #"r" gs-expr tp) (map number->bytes s)] [(tcomp-ds-ref index) - (sign - (format "dsr~a" index))]))) + (sign #"dsr" index)]))) ;; Count references so that the tcomp AST nodes that refer to the same memory ;; location i.e. common AST nodes get extracted by let-binding them in the From f2aecf263e5f8c199b9a6f816f7b628b01f5cb94 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 17 Nov 2023 12:30:51 -0500 Subject: [PATCH 68/83] [add-lazy]Improve caching in gradient descent --- malted/A-core.rkt | 4 +++- malted/D-gradient-descent.rkt | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/malted/A-core.rkt b/malted/A-core.rkt index 6933b20..4d50c79 100644 --- a/malted/A-core.rkt +++ b/malted/A-core.rkt @@ -1,6 +1,8 @@ #lang racket (require "../base.rkt") +;; TODO: This is not implementation independent. Figure out a fix +(require (only-in "../lazy/tensors.rkt" ↓)) (define dot-product (λ (w t) @@ -14,4 +16,4 @@ (include "test/test-A-core.rkt") -(provide dot-product dot-product-2-1) +(provide dot-product dot-product-2-1 ↓) diff --git a/malted/D-gradient-descent.rkt b/malted/D-gradient-descent.rkt index 33a3d28..b89bef4 100644 --- a/malted/D-gradient-descent.rkt +++ b/malted/D-gradient-descent.rkt @@ -17,7 +17,7 @@ (λ (obj theta) (let ((ctr 0)) (let ((f (λ (big-theta) - (map update + (map (λ (pa g) (map* ↓ (update pa g))) big-theta (gradient-of obj (map deflate big-theta)))))) From 718d21d65d132a8b28c72213b57f3cb3acf561ef Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 18 Nov 2023 04:14:22 -0500 Subject: [PATCH 69/83] [add-lazy]Fix signature generation --- lazy/autodiff/A-autodiff.rkt | 3 +-- lazy/tensors/c3-compiler.rkt | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index e76b165..e7399a1 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -61,8 +61,7 @@ (define ∇ (λ (f theta) (let ((wrt (map* dual* theta))) - ;; TODO: try forcing (f wrt) to see if it fixes caching issues - (∇-once (f wrt) #;(↓ (f wrt)) wrt)))) + (∇-once (f wrt) wrt)))) (define ∇¹ (λ (f) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 198e3ed..80a4d27 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -123,7 +123,7 @@ (values (tcomp-reshape s tp^) dss^))]))) (define hash-signatures? - (make-parameter #f)) + (make-parameter #t)) (define sign (let ((xxh32-ctx (make-xxh32))) @@ -131,8 +131,8 @@ (cond ((hash-signatures?) (xxh32-reset! xxh32-ctx 0) - (xxh32-update! xxh32-ctx (apply bytes-join ss #"_")) - (xxh32-digest xxh32-ctx)) + (xxh32-update! xxh32-ctx (bytes-join ss #"_")) + (number->bytes (xxh32-digest xxh32-ctx))) (else (format "~a" ss)))))) (define number->bytes @@ -159,7 +159,7 @@ [(tcomp-list->tensor lst) (apply sign #"l>t" (map gs-expr lst))] [(tcomp-tref tp i) - (sign "tr" (gs-expr tp) (gs-expr i))] + (sign #"tr" (gs-expr tp) (gs-expr i))] [(tcomp-trefs tp b) (sign #"trs" (gs-expr tp) (gs-expr b))] [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) @@ -180,9 +180,9 @@ [(tcomp-ext1-ρ f signature m shape-fn tp) (sign #"e1r" (string->bytes signature) (number->bytes m) (gs-expr tp))] [(tcomp-reshape s tp) - (apply sign #"r" gs-expr tp) (map number->bytes s)] + (apply sign #"r" (gs-expr tp) (map number->bytes s))] [(tcomp-ds-ref index) - (sign #"dsr" index)]))) + (sign #"dsr" (number->bytes index))]))) ;; Count references so that the tcomp AST nodes that refer to the same memory ;; location i.e. common AST nodes get extracted by let-binding them in the From 750c41745475a2ce0e20c32e4685251fb90e7d87 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 24 Nov 2023 19:45:10 -0500 Subject: [PATCH 70/83] [add-lazy]Refactor test cases --- lazy/tensors/1-reflect.rkt | 14 +- lazy/tensors/B-test-programs.rkt | 369 ++++++++++++++++++------- lazy/tensors/test/test-1-reflect.rkt | 243 +++------------- lazy/tensors/test/test-c3-compiler.rkt | 1 + 4 files changed, 318 insertions(+), 309 deletions(-) diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt index 9dbae5f..aad9591 100644 --- a/lazy/tensors/1-reflect.rkt +++ b/lazy/tensors/1-reflect.rkt @@ -10,16 +10,14 @@ (require (only-in "c2-interpreter.rkt" interp-racket)) (define ↓ - (lambda (tp (print? #f)) - (when print? - (printf "~n####PP tensor: ") - (pretty-print tp)) + (lambda (tp) (match tp + [(tpromise v _) + #:when (or (flat:flat? v) (number? v)) + v] [(tpromise t _) - #:when (or (flat:flat? t) (number? t) (tcomp? t)) - - (let-values (((instrs data-segment) - (compile-tensor t))) + #:when (tcomp? t) + (let-values (((instrs data-segment) (compile-tensor t))) (let ((res (interp-racket instrs data-segment))) (set-tpromise-tensor! tp res) res))] diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt index f75b2b1..f247b89 100644 --- a/lazy/tensors/B-test-programs.rkt +++ b/lazy/tensors/B-test-programs.rkt @@ -11,81 +11,321 @@ (list->tensor l))) (struct test-program-data (prog-thunk eval-res) #:transparent) +(struct eval-res-1 (res) #:transparent) +(struct eval-res-2 (res1 res2) #:transparent) + + +;; Care must be taken while calling get-test-program within the +;; test-program-data thunk because it might lead to an infinite loop. (define test-programs (hasheqv 'tensor-r1-0 (test-program-data (λ () (tensor 1 2 3)) - (flat:tensor 1 2 3)) + (eval-res-1 (flat:tensor 1 2 3))) 'tensor-r1-1 (test-program-data (λ () (tensor 1 2 3 4 5)) - (flat:tensor 1 2 3 4 5)) + (eval-res-1 (flat:tensor 1 2 3 4 5))) + 'tensor-r1-2 (test-program-data + (λ () + (tensor 3.0 4.0 5.0)) + (eval-res-1 (flat:tensor 3.0 4.0 5.0))) 'tensor-r2-0 (test-program-data (λ () (tensor (tensor 1 2 3) (tensor 4 5 6))) - (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6))) + (eval-res-1 (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6)))) + 'tensor-r2-1 (test-program-data + (λ () + (reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) + (eval-res-1 + (flat:reshape '(2 3) (flat:tensor 3.0 4.0 5.0 7.0 8.0 9.0)))) + 'build-tensor-r1-0 (test-program-data + (λ () + (build-tensor '(6) + (λ (i) (* 3.0 (car i))))) + (eval-res-1 (flat:build-tensor '(6) + (λ (i) (* 3.0 (car i)))))) 'build-tensor-r2-0 (test-program-data (λ () (build-tensor '(5 6) (λ (i) (match-define `(,x ,y) i) (* 2.0 (+ (* x 6) y))))) - (flat:build-tensor '(5 6) - (λ (i) - (match-define `(,x ,y) i) - (* 2.0 (+ (* x 6) y))))) + (eval-res-1 (flat:build-tensor '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y)))))) + 'build-tensor-r2-1 (test-program-data + (λ () + (build-tensor '(3 6) + (λ (i) + (match-define `(,x ,y) i) + (* 3.0 (+ (* x 6) y))))) + (eval-res-1 (flat:build-tensor '(3 6) + (λ (i) + (match-define `(,x ,y) i) + (* 3.0 (+ (* x 6) y)))))) 'build-tensor-r3-0 (test-program-data (λ () (build-tensor '(2 3 4) (λ (i) (match-define `(,x ,y ,z) i) (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) - (flat:build-tensor '(2 3 4) - (λ (i) - (match-define `(,x ,y ,z) i) - (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) + (eval-res-1 (flat:build-tensor + '(2 3 4) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2 (+ (* x 12) (* y 4) (* 1 z))))))) 'build-tensor-r3-1 (test-program-data (λ () (build-tensor '(3 5 6) (λ (i) (match-define `(,x ,y ,z) i) (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) - (flat:build-tensor '(3 5 6) - (λ (i) - (match-define `(,x ,y ,z) i) - (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) + (eval-res-1 (flat:build-tensor + '(3 5 6) + (λ (i) + (match-define `(,x ,y ,z) i) + (* 2.0 (+ (* x 30) (* y 6) (* 1 z))))))) 'extract-ds-once-tref (test-program-data (λ () (let ((n (tref (get-test-program 'tensor-r1-0) 1))) (+-ρ n n))) - 4) + (eval-res-1 4)) 'extract-ds-once-trefs (test-program-data (λ () - (let ((tp (trefs (get-test-program 'tensor-r1-0) '(0 2)))) + (let ((tp (trefs (get-test-program 'tensor-r1-0) + '(0 2)))) (+-ρ tp tp))) - (flat:tensor 2 6)) + (eval-res-1 (flat:tensor 2 6))) 'built-tensor (test-program-data (λ () - (build-tensor test-build-shape - (λ (i) - (let ([row (car i)] - [column (cadr i)]) - (+ (* (sub1 (car test-build-shape)) - row) - column))))) - (flat:tensor (flat:tensor 0 1 2) - (flat:tensor 3 4 5) - (flat:tensor 6 7 8) - (flat:tensor 9 10 11))) + (let ((test-build-shape '(4 3))) + (build-tensor test-build-shape + (λ (i) + (let ([row (car i)] + [column (cadr i)]) + (+ (* (sub1 (car test-build-shape)) + row) + column)))))) + (eval-res-1 (flat:tensor (flat:tensor 0 1 2) + (flat:tensor 3 4 5) + (flat:tensor 6 7 8) + (flat:tensor 9 10 11)))) 'multi-built-tensor (test-program-data (λ () (+-ρ (get-test-program 'build-tensor-r2-0) (tref (get-test-program 'build-tensor-r3-1) 0))) - ((flat:ext2-ρ * 0 0) 2 (flat:build-tensor '(5 6) - (λ (i) - (match-define `(,x ,y) i) - (* 2.0 (+ (* x 6) y)))))) + (eval-res-1 ((flat:ext2-ρ * 0 0) 2 + (flat:build-tensor + '(5 6) + (λ (i) + (match-define `(,x ,y) i) + (* 2.0 (+ (* x 6) y))))))) + 'tcomp-tref (test-program-data + (λ () + (make-tref-test-program (get-test-program 'tensor-r1-0))) + (eval-res-1 3)) + 'tcomp-tref-nested (test-program-data + (λ () + (tref (tref (get-test-program 'tensor-r2-0) 0) 2)) + (eval-res-1 3)) + 'tcomp-list->tensor (test-program-data + (λ () + (make-list->tensor-test-program '(5 6 7 8))) + (eval-res-1 (flat:tensor 5 6 7 8))) + 'tcomp-nested-list->tensor (test-program-data + (λ () + (list->tensor `(,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0)))) + (eval-res-1 (flat:tensor + (flat:tensor 1 2 3) + (flat:tensor 1 2 3) + (flat:tensor 1 2 3)))) + 'tcomp-trefs (test-program-data + (λ () + (trefs (get-test-program 'built-tensor) '(0 2))) + (eval-res-1 (flat:tensor (flat:tensor 0 1 2) + (flat:tensor 6 7 8)))) + 'tcomp-reshape (test-program-data + (λ () + (reshape '(3 2 1) + (trefs (get-test-program 'built-tensor) '(1 3)))) + (eval-res-1 (flat:tensor (flat:tensor (flat:tensor 3) + (flat:tensor 4)) + (flat:tensor (flat:tensor 5) + (flat:tensor 9)) + (flat:tensor (flat:tensor 10) + (flat:tensor 11))))) + 'sum (test-program-data + (λ () + (sum (get-test-program 'tensor-r2-0))) + (eval-res-1 (flat:tensor 6.0 15.0))) + 'sum-nested (test-program-data + (λ () + (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) + (eval-res-1 (flat:tensor 4.0 6.0 5.0))) + 'id (test-program-data + (λ () + (id-ρ (get-test-program 'tensor-r2-0))) + (eval-res-1 (flat:tensor (flat:tensor 1 2 3) + (flat:tensor 4 5 6)))) + 'id-scalar (test-program-data + (λ () + (id-ρ (sum (tensor 4 5 6)))) + (eval-res-1 15)) + 'sqr (test-program-data + (λ () + (*-ρ (get-test-program 'build-tensor-r3-0) + (get-test-program 'build-tensor-r3-0))) + (eval-res-1 (flat:reshape + '(2 3 4) + (flat:tensor + 0 4 16 36 + 64 100 144 196 + 256 324 400 484 + 576 676 784 900 + 1024 1156 1296 1444 + 1600 1764 1936 2116)))) + 'r-1-2 (test-program-data + (λ () + (*-2-1 (get-test-program 'build-tensor-r2-0) + (get-test-program 'build-tensor-r1-0))) + (eval-res-1 (flat:reshape + '(5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0)))) + 'r-3-4 (test-program-data + (λ () + (*-2-1 (get-test-program 'build-tensor-r3-1) + (get-test-program 'build-tensor-r2-1))) + (eval-res-1 (flat:reshape + '(3 5 6) + (flat:tensor + 0 6.0 24.0 54.0 96.0 150.0 + 0 42.0 96.0 162.0 240.0 330.0 + 0 78.0 168.0 270.0 384.0 510.0 + 0 114.0 240.0 378.0 528.0 690.0 + 0 150.0 312.0 486.0 672.0 870.0 + + 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 + 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 + 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 + 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 + 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 + + 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 + 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 + 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 + 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 + 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0)))) + 'r-sum-2-scalar (test-program-data + (λ () + (*-ρ (sum (get-test-program 'build-tensor-r1-0)) + (sum (tensor 2 3 4)))) + (eval-res-1 405.0)) + 'tcomp-dsqr-r1 (test-program-data + (λ () + (d-sqr r1-td (one-like r1-td))) + (eval-res-1 (flat:tensor 6.0 8.0 10.0))) + 'gsqr (test-program-data + (λ () + (let ([r2-td (get-test-program 'tensor-r2-1)]) + (d-sqr r2-td (one-like r2-td)))) + (eval-res-1 (flat:reshape + '(2 3) + (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + 'g+ (test-program-data + (λ () + (d+ 2.0 3.0 1.0)) + (eval-res-2 1.0 1.0)) + 'g-twice (test-program-data + (λ () + (d+ r1-td r1-td (one-like r1-td))) + (eval-res-2 (flat:tensor 1.0 1.0 1.0) + (flat:tensor 1.0 1.0 1.0))) + 'g+-r1-r2 (test-program-data + (λ () + (let ((r2-td (get-test-program 'tensor-r2-1))) + (d+ r1-td r2-td (one-like r2-td)))) + (eval-res-2 (flat:tensor 2.0 2.0 2.0) + (flat:reshape + '(2 3) + (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + 'g* (test-program-data + (λ () + (*∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0 1.0))) + (eval-res-2 (flat:tensor 1.0 2.0 3.0) + (flat:tensor 2.0 3.0 4.0))) + 'gsum-r1 (test-program-data + (λ () + (sum-∇ (tensor 2.0 3.0 4.0) + 1.0)) + (eval-res-1 (flat:tensor 1.0 1.0 1.0))) + 'gsum-r2 (test-program-data + (λ () + (sum-∇ (tensor (tensor 2.0 3.0 4.0) + (tensor 2.0 3.0 4.0)) + (tensor 2.0 1.0))) + (eval-res-1 (flat:tensor (flat:tensor 2.0 2.0 2.0) + (flat:tensor 1.0 1.0 1.0)))) + 'gs2-r1 (test-program-data + (λ () + (s2-∇ (tensor 2.0 3.0 4.0) + (tensor 1.0 2.0 3.0) + (tensor 1.0 1.0))) + (eval-res-2 (flat:tensor 1.0 1.0 1.0) + (flat:tensor 1.0 1.0 1.0))) + 'gs2-r3 (test-program-data + (λ () + (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 5.0 6.0 6.0) + (tensor 7.0 8.0 6.0)) + (tensor (tensor 8.0 7.0 6.0) + (tensor 6.0 5.0 6.0))) + (tensor (tensor (tensor 6.0 8.0 6.0) + (tensor 3.0 4.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 8.0 2.0 6.0)) + (tensor (tensor 9.0 7.0 6.0) + (tensor 5.0 1.0 6.0))) + (tensor (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0)) + (tensor (tensor 1.0 1.0) + (tensor 1.0 1.0))))) + (eval-res-2 (flat:reshape '(3 2 3) + (flat:list->tensor (make-list 18 1.0))) + (flat:reshape '(3 2 3) + (flat:list->tensor (make-list 18 1.0))))) + 'env-flat-scalar (test-program-data + (λ () + ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) + (list (tensor 1.0) 3.0))) + (eval-res-1 (flat:tensor 3.0))) + 'common-subexpression (test-program-data + (λ () + (let ((t (tref (tensor 1 2 3) 0))) + (tensor t t))) + (eval-res-1 (flat:tensor 1.0 1.0))) + 'nested-common-subexpression (test-program-data + (λ () + (let ((t1 (tref (tensor (tensor 1 2 3) + (tensor 4 5 6)) + 0))) + (let ((t0 (tref t1 0))) + (tensor t0 t0)))) + (eval-res-1 (flat:tensor 1.0 1.0))) )) (define get-test-program @@ -95,19 +335,6 @@ (λ (name) (test-program-data-eval-res (hash-ref test-programs name)))) -(define test-tcomp-tref (make-tref-test-program (get-test-program 'tensor-r1-0))) -(define test-tcomp-tref-nested (tref (tref (get-test-program 'tensor-r2-0) 0) 2)) -(define test-list->tensor (make-list->tensor-test-program '(5 6 7 8))) -(define test-nested-list->tensor - (list->tensor `(,(get-test-program 'tensor-r1-0) - ,(get-test-program 'tensor-r1-0) - ,(get-test-program 'tensor-r1-0)))) -(define test-build-shape '(4 3)) - -(define test-refs '(0 2)) -(define test-trefs (trefs (get-test-program 'built-tensor) test-refs)) -(define test-reshape (reshape '(3 2 1) (trefs (get-test-program 'built-tensor) '(1 3)))) - (define sum-f (λ (in-v iᵢ sᵢ out-v iₒ sₒ) (vset! out-v iₒ @@ -115,19 +342,11 @@ (+ sum (vref in-v i)))))) (define sum (ext1-ρ sum-f 1 (λ (s) '()) #t)) -(define test-tp-sum (sum (get-test-program 'tensor-r2-0))) -(define test-tp-sum-nested (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) (define id-f (lambda (v) v)) (define id-ρ (ext1-ρ id-f 1 (λ (s) s))) -(define test-tp-id (id-ρ (get-test-program 'tensor-r2-0))) -(define test-tp-id-scalar (id-ρ (sum (tensor 4 5 6)))) - -(define t0 - (get-test-program 'build-tensor-r3-0)) (define *-ρ (ext2-ρ * 0 0)) -(define t0sqr (*-ρ t0 t0)) (define *-2-1-f (λ (v0 i0 s0 v1 i1 s1 vout iout sout) @@ -136,35 +355,13 @@ (* (vref v0 (+ i0 j0)) (vref v1 (+ i1 (modulo j0 s1)))))))) -(define t1 - (get-test-program 'build-tensor-r2-0)) - (define t2 - (build-tensor '(6) - (λ (i) (* 3.0 (car i))))) + (get-test-program 'build-tensor-r1-0)) (define *-2-1 (ext2-ρ *-2-1-f 2 1 (λ (s0 s1) s0) #t)) -(define r-1-2 - (*-2-1 t1 t2)) - -(define t3 - (get-test-program 'build-tensor-r3-1)) - -(define t4 - (build-tensor '(3 6) - (λ (i) - (match-define `(,x ,y) i) - (* 3.0 (+ (* x 6) y))))) - -(define r-3-4 - (*-2-1 t3 t4)) - -(define r-sum-2-scalar (*-ρ (sum t2) (sum (tensor 2 3 4)))) - -(define r1-td (tensor 3.0 4.0 5.0)) -(define r2-td (reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) +(define r1-td (get-test-program 'tensor-r1-2)) (define +ᶠ +) (define +ᵈ (λ (a b z) (values z z))) @@ -179,8 +376,6 @@ (λ (t) (build-tensor (shape t) (λ (_) 1.0)))) -(define tcomp-dsqr-r1 (d-sqr r1-td (one-like r1-td))) - (define d+ (ext2-∇ +ᵈ 0 0)) (define *∇ (ext2-∇ (λ (a b z) (values (* z b) (* z a))) 0 0)) @@ -201,20 +396,6 @@ (vset! g1 i (vref vz (+ iz 1)))))) (define s2-∇ (ext2-∇ s2-d 1 1 (λ (s0 s1) (list 2)) #t)) -(define test-env-flat-scalar - ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) - (list (tensor 1.0) 3.0))) - -;; Check common subexpression introduced by let is not repeated -(define test-common-subexpr - (let ((t (tref (tensor 1 2 3) 0))) - (tensor t t))) - -(define test-common-nested-subexprs - (let ((t1 (tref (tensor (tensor 1 2 3) (tensor 4 5 6)) 0))) - (let ((t0 (tref t1 0))) - (tensor t0 t0)))) - (define random-tensor (λ (s) (build-tensor s (λ (tidx) (random 10))))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt index 1724f17..c5f1d78 100644 --- a/lazy/tensors/test/test-1-reflect.rkt +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -31,24 +31,45 @@ ;;TODO: Move all check-compiler-invariant checks to the test file for ;;c3-compiler.rkt file. - ;;TODO: Refactor all test cases to use get-test-program so that ↓ doesn't - ;;mutate the programs defined in B-test-programs + (for (((test-name test-data) (in-hash test-programs))) + (match-define (test-program-data th res) test-data) + (match res + ((eval-res-1 res) + (let* ((tp (th)) + (forced (↓ tp))) + (flat:check-tensor-equal? + forced res + (format "Expected result doesn't match in test case ~a" + test-name)) + (check-false (tcomp? (tpromise-tensor tp))) + (check-equal? (tpromise-shape tp) (flat:shape forced)))) + ((eval-res-2 res1 res2) + (let*-values (((tp1 tp2) (th)) + ((forced1) (↓ tp1)) + ((forced2) (↓ tp2))) + (flat:check-tensor-equal? + forced1 res1 + (format "Expected first result doesn't match in test case ~a" + test-name)) + (check-false (tcomp? (tpromise-tensor tp1))) + (check-equal? (tpromise-shape tp1) (flat:shape forced1)) + (flat:check-tensor-equal? + forced2 res2 + (format "Expected second result doesn't match in test case ~a" + test-name)) + (check-false (tcomp? (tpromise-tensor tp2))) + (check-equal? (tpromise-shape tp2) (flat:shape forced2)))))) + + (define test-tensor-r1-0 (get-test-program 'tensor-r1-0)) - (check-compiler-invariants test-tensor-r1-0) - (check-true (flat:flat? (tpromise-tensor test-tensor-r1-0))) - (flat:check-tensor-equal? (↓ test-tensor-r1-0) - (get-test-eval-res 'tensor-r1-0)) (check-true (flat:flat? (tpromise-tensor test-tensor-r1-0))) (check-exn exn:fail? (λ () (tensor test-tensor-r1-0 4))) (check-exn exn:fail? (λ () (tensor 4 test-tensor-r1-0))) - (check-compiler-invariants test-tcomp-tref) - (check-equal? (↓ test-tcomp-tref) 3) + (define test-tcomp-tref (get-test-program 'tcomp-tref)) (check-exn exn:fail? (λ () (tref test-tensor-r1-0 5))) (define test-nested-tensor (get-test-program 'tensor-r2-0)) - (check-compiler-invariants test-tcomp-tref-nested) - (check-equal? (↓ test-tcomp-tref-nested) 3) (check-exn exn:fail? (λ () (tref (tref test-nested-tensor 2) 0))) (check-exn exn:fail? (λ () (tref test-nested-tensor 2))) (check-exn exn:fail? (λ () (tensor test-nested-tensor test-nested-tensor test-tensor-r1-0))) @@ -56,11 +77,8 @@ (check-equal? (tlen test-tensor-r1-0) 3) (check-equal? (tlen test-nested-tensor) 2) - (check-compiler-invariants test-list->tensor) - (check-equal? (flat:flat-store (↓ test-list->tensor)) (vector 5 6 7 8)) - (check-compiler-invariants test-nested-list->tensor) - (check-equal? (flat:flat-store (↓ test-nested-list->tensor)) - (vector 1 2 3 1 2 3 1 2 3)) + (define test-nested-list->tensor + (get-test-program 'tcomp-nested-list->tensor)) (check-equal? (tpromise-shape test-nested-list->tensor) '(3 3)) (define test-tcomp-partial-eval @@ -76,210 +94,21 @@ test-tensor-r1-0))) 1) 2))) - (check-compiler-invariants test-tcomp-partial-eval) (flat:check-tensor-equal? (↓ test-tcomp-partial-eval) (↓ (tensor 1 2 3))) - (define test-built-tensor (get-test-program 'built-tensor)) - (check-compiler-invariants test-built-tensor) - (check-equal? (tpromise-shape test-built-tensor) test-build-shape) - (flat:check-tensor-equal? (↓ test-built-tensor) - (↓ (tensor (tensor 0 1 2) - (tensor 3 4 5) - (tensor 6 7 8) - (tensor 9 10 11)))) - - (check-compiler-invariants test-trefs) + (define test-trefs (get-test-program 'tcomp-trefs)) (check-true (tcomp? (tpromise-tensor test-trefs))) - (check-equal? (tpromise-shape test-trefs) - (flat:flat-shape (↓ test-trefs))) - (flat:check-tensor-equal? (↓ test-trefs) - (↓ (tensor (tensor 0 1 2) - (tensor 6 7 8)))) (check-exn exn:fail? (λ () (trefs test-nested-tensor '(0 4)))) - (check-compiler-invariants test-reshape) - (flat:check-tensor-equal? (↓ test-reshape) - (↓ (tensor (tensor (tensor 3) - (tensor 4)) - (tensor (tensor 5) - (tensor 9)) - (tensor (tensor 10) - (tensor 11))))) + (define test-reshape (get-test-program 'tcomp-reshape)) (check-exn exn:fail? (λ () (reshape '(4 5) test-reshape))) - (check-compiler-invariants test-tp-sum) - (flat:check-tensor-equal? (↓ test-tp-sum) - (↓ (tensor 6.0 15.0))) - - (check-compiler-invariants test-tp-sum-nested) - (flat:check-tensor-equal? (↓ test-tp-sum-nested) - (↓ (tensor 4.0 6.0 5.0))) - - (check-compiler-invariants test-tp-id) - (flat:check-tensor-equal? (↓ test-tp-id) - (↓ (tensor (tensor 1 2 3) - (tensor 4 5 6)))) - - (check-compiler-invariants test-tp-id-scalar) - (check-equal? (↓ test-tp-id-scalar) 15.0) - - (check-compiler-invariants t0sqr) - (flat:check-tensor-equal? (↓ t0sqr) - (flat:reshape - '(2 3 4) - (flat:tensor - 0 4 16 36 - 64 100 144 196 - 256 324 400 484 - 576 676 784 900 - 1024 1156 1296 1444 - 1600 1764 1936 2116))) - - (check-equal? (tpromise-shape r-1-2) '(5 6)) - (check-compiler-invariants r-1-2) - (flat:check-tensor-equal? (↓ r-1-2) - (flat:reshape - '(5 6) - (flat:tensor - 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0))) - - (check-equal? (tpromise-shape r-3-4) '(3 5 6)) - (check-compiler-invariants r-3-4) - (flat:check-tensor-equal? (↓ r-3-4) - (flat:reshape - '(3 5 6) - (flat:tensor - 0 6.0 24.0 54.0 96.0 150.0 - 0 42.0 96.0 162.0 240.0 330.0 - 0 78.0 168.0 270.0 384.0 510.0 - 0 114.0 240.0 378.0 528.0 690.0 - 0 150.0 312.0 486.0 672.0 870.0 - - 1080.0 1302.0 1536.0 1782.0 2040.0 2310.0 - 1296.0 1554.0 1824.0 2106.0 2400.0 2706.0 - 1512.0 1806.0 2112.0 2430.0 2760.0 3102.0 - 1728.0 2058.0 2400.0 2754.0 3120.0 3498.0 - 1944.0 2310.0 2688.0 3078.0 3480.0 3894.0 - - 4320.0 4758.0 5208.0 5670.0 6144.0 6630.0 - 4752.0 5226.0 5712.0 6210.0 6720.0 7242.0 - 5184.0 5694.0 6216.0 6750.0 7296.0 7854.0 - 5616.0 6162.0 6720.0 7290.0 7872.0 8466.0 - 6048.0 6630.0 7224.0 7830.0 8448.0 9078.0))) - - (check-compiler-invariants r-sum-2-scalar) - (flat:check-tensor-equal? (↓ r-sum-2-scalar) 405.0) - - (check-compiler-invariants tcomp-dsqr-r1) - (flat:check-tensor-equal? (↓ tcomp-dsqr-r1) - (flat:tensor 6.0 8.0 10.0)) - - (let ((gsqr (d-sqr r2-td (one-like r2-td)))) - (check-compiler-invariants gsqr) - (flat:check-tensor-equal? (↓ gsqr) - (flat:reshape - '(2 3) - (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) - - (let-values (((da db) (d+ 2.0 3.0 1.0))) - (check-compiler-invariants da) - (flat:check-tensor-equal? (↓ da) 1.0) - (check-compiler-invariants db) - (flat:check-tensor-equal? (↓ db) 1.0)) - - (let-values (((da db) (d+ r1-td r1-td (one-like r1-td)))) - (check-compiler-invariants da) - (flat:check-tensor-equal? (↓ da) - (flat:tensor 1.0 1.0 1.0)) - (check-compiler-invariants db) - (flat:check-tensor-equal? (↓ db) - (flat:tensor 1.0 1.0 1.0))) - - (let-values (((da db) (d+ r1-td r2-td (one-like r2-td)))) - (check-compiler-invariants da) - (flat:check-tensor-equal? (↓ da) - (flat:tensor 2.0 2.0 2.0)) - (check-compiler-invariants db) - (flat:check-tensor-equal? (↓ db) - (flat:reshape - '(2 3) - (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) - - (let-values (((gt gu) (*∇ (tensor 2.0 3.0 4.0) - (tensor 1.0 2.0 3.0) - (tensor 1.0 1.0 1.0)))) - (check-compiler-invariants gt) - (check-compiler-invariants gu) - (flat:check-tensor-equal? (↓ gt) (↓ (tensor 1.0 2.0 3.0))) - (flat:check-tensor-equal? (↓ gu) (↓ (tensor 2.0 3.0 4.0)))) - - (let ((gt (sum-∇ (tensor 2.0 3.0 4.0) - 1.0))) - (check-compiler-invariants gt) - (flat:check-tensor-equal? (↓ gt) (↓ (tensor 1.0 1.0 1.0)))) - - (let ((gt (sum-∇ (tensor (tensor 2.0 3.0 4.0) - (tensor 2.0 3.0 4.0)) - (tensor 2.0 1.0)))) - (check-compiler-invariants gt) - (flat:check-tensor-equal? (↓ gt) (↓ (tensor (tensor 2.0 2.0 2.0) - (tensor 1.0 1.0 1.0))))) - (let-values (((gt gu) (s2-∇ (tensor 2.0 3.0 4.0) - (tensor 1.0 2.0 3.0) - (tensor 1.0 1.0)))) - (check-compiler-invariants gt) - (check-compiler-invariants gu) - (flat:check-tensor-equal? (↓ gt) (↓ (tensor 1.0 1.0 1.0))) - (flat:check-tensor-equal? (↓ gu) (↓ (tensor 1.0 1.0 1.0)))) - (let-values (((gt gu) (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) - (tensor 3.0 4.0 6.0)) - (tensor (tensor 5.0 6.0 6.0) - (tensor 7.0 8.0 6.0)) - (tensor (tensor 8.0 7.0 6.0) - (tensor 6.0 5.0 6.0))) - (tensor (tensor (tensor 6.0 8.0 6.0) - (tensor 3.0 4.0 6.0)) - (tensor (tensor 9.0 7.0 6.0) - (tensor 8.0 2.0 6.0)) - (tensor (tensor 9.0 7.0 6.0) - (tensor 5.0 1.0 6.0))) - (tensor (tensor (tensor 1.0 1.0) - (tensor 1.0 1.0)) - (tensor (tensor 1.0 1.0) - (tensor 1.0 1.0)) - (tensor (tensor 1.0 1.0) - (tensor 1.0 1.0)))))) - (check-compiler-invariants gt) - (check-compiler-invariants gu) - (flat:check-tensor-equal? - (↓ gt) - (↓ (reshape '(3 2 3) (list->tensor (make-list 18 1.0))))) - (flat:check-tensor-equal? - (↓ gu) - (↓ (reshape '(3 2 3) (list->tensor (make-list 18 1.0)))))) - - (check-compiler-invariants test-env-flat-scalar) - (flat:check-tensor-equal? (↓ test-env-flat-scalar) - (flat:tensor 3.0)) - - (check-compiler-invariants test-common-subexpr) - (flat:check-tensor-equal? (↓ test-common-subexpr) - (flat:tensor 1.0 1.0)) - - (check-compiler-invariants test-common-nested-subexprs) - (flat:check-tensor-equal? (↓ test-common-nested-subexprs) - (flat:tensor 1.0 1.0)) - (check-pred (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) (vector->list (flat:flat-store (↓ test-build-random))) "Side-effect of generating random tensor must only be run once") (flat:check-tensor-equal? (↓ (get-test-program 'multi-built-tensor)) - (get-test-eval-res 'multi-built-tensor)) + (eval-res-1-res (get-test-eval-res 'multi-built-tensor))) ) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index 33093ea..76c49cb 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -34,6 +34,7 @@ (fail-check "signatures musn't match")))))) (define test-tensor-r1-1 (get-test-program 'tensor-r1-1)) + (define test-tcomp-tref (get-test-program 'tcomp-tref)) (check-signatures-equal? test-tcomp-tref (make-tref-test-program test-tensor-r1-1)) (check-signatures-not-equal? test-tcomp-tref From 4b45e4999a5fc876329e4da3892ad2dc67f6b1e1 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 25 Nov 2023 12:21:22 -0500 Subject: [PATCH 71/83] [add-lazy]Optimize racket eval --- lazy/tensors/c3-compiler.rkt | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 80a4d27..569ed61 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -4,6 +4,7 @@ (require (only-in "c2-interpreter.rkt" interp-tensor interp-racket)) (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require (only-in "c1-racket-runtime.rkt" runtime)) (require rackunit) (require file/xxhash32) @@ -530,11 +531,17 @@ (display-compiler-trace 'Count-References counter) (let ((extracted (extract-common-subexpressions eds-instrs counter))) (display-compiler-trace 'Extract-Common-Subexpressions extracted) - (let ((rkt (generate-racket extracted))) - (display-compiler-trace 'Generate-Racket rkt) + (let* ((gr (generate-racket extracted)) + (rkt (compile-racket gr))) + (display-compiler-trace 'Generate-Racket gr) (hash-set! (cache) signature rkt) (values rkt ds)))))))))) +(define compile-racket + (λ (r) + (parameterize ([current-namespace runtime]) + (compile-syntax (expand r))))) + ;;TODO: update this for new compiler passes (define compile-tensor/checks (λ (t) From 1d27b91fc8e4fd11f2d9cb6026e2e84194194aae Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 22 Dec 2023 23:53:52 -0500 Subject: [PATCH 72/83] [add-lazy]Build DS and signature while building AST --- lazy/autodiff/D-test-helpers.rkt | 5 +- lazy/tensors/0-lazy.rkt | 140 ++++--- lazy/tensors/1-reflect.rkt | 20 +- lazy/tensors/B-test-programs.rkt | 13 +- lazy/tensors/c0-ast.rkt | 242 +++++++++++- lazy/tensors/c1-racket-runtime.rkt | 18 +- lazy/tensors/c2-interpreter.rkt | 2 +- lazy/tensors/c3-compiler.rkt | 518 +++++++++++-------------- lazy/tensors/test/test-1-reflect.rkt | 67 ++-- lazy/tensors/test/test-c3-compiler.rkt | 87 +++-- 10 files changed, 658 insertions(+), 454 deletions(-) diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index 1b1df9f..caac553 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -7,6 +7,7 @@ (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) + ;; TODO: This code ahould work even after removing the ↓ call (let* ((y (↓ (apply fn args))) (g (↓ (apply (∇¹ fn) args))) (ans-ρ (ρ ans))) @@ -14,10 +15,10 @@ ((and (equal-wt? ans-ρ (ρ y)) (equal-wt? grads (ρ g))) (void)) ((equal-wt? ans-ρ (ρ y)) - (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~s~%" + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~%~s~%" (ρ g) grads))) (else - (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~s~%" + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~%~s~%" (ρ y) ans-ρ)))))) (define-syntax check-ρ-∇ diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index cc0d047..b0afa55 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -58,17 +58,6 @@ instructions refering to the same gensym variable (λ args (list->tpromise args))) #; -(: tensor-inner-flat (-> (Listof (U tpromise Number)) - (U flat tcomp-list->tensor))) -(define tensor-inner-flat - (λ (lst) - (cond - [(andmap number? lst) (apply flat:tensor lst)] - [(andmap (λ (v) (and (tpromise? v) (flat:flat? (tpromise-tensor v)))) lst) - (apply flat:tensor (map tpromise-tensor lst))] - [else (tcomp-list->tensor lst)]))) - -#; (: ensure-shape (-> (U (Listof tpromise) (Listof Number)) Void)) (define ensure-shape (λ (args) @@ -98,18 +87,31 @@ instructions refering to the same gensym variable "Cannot construct a tensor out of these elements: ~a~%" args))))) +#; +(: tensor-inner-flat (-> (Listof (U tpromise Number)) + (U flat tcomp-list->tensor))) +(define tensor-inner-flat + (λ (lst) + (cond + [(andmap number? lst) (apply flat:tensor lst)] + [(andmap tpromise-flat? lst) + (apply flat:tensor + (for/list ((tp-flat lst)) + (car (unbox (tpromise-dst tp-flat)))))] + [else lst]))) + (define list->tpromise (λ (lst) (ensure-shape lst) - (let ((inner-flat (tensor-inner-flat lst))) + (let ((inner-tensor (tensor-inner-flat lst))) (cond - ((flat? inner-flat) - (tpromise inner-flat (flat-shape inner-flat))) + ((flat? inner-tensor) + (tpmake-flat inner-tensor)) (else (let* ((inner-shape (tp-shape (car lst))) (outer (length lst)) (new-shape (cons outer inner-shape))) - (tpromise inner-flat new-shape))))))) + (tpmake-list->tensor inner-tensor new-shape))))))) (define bounded-idx*^ (λ (shape idx*) @@ -129,8 +131,7 @@ instructions refering to the same gensym variable (lambda (tp i) (cond [(bounded-idx*? tp (list i)) - (tpromise (tcomp-tref tp i) - (cdr (tpromise-shape tp)))] + (tpmake-tref tp i (cdr (tpromise-shape tp)))] [else (error 'exn:tp-tref (string-append "Index out of bounds. ~a " @@ -150,7 +151,7 @@ instructions refering to the same gensym variable (define build-tpromise (λ (s f) - (tpromise (flat:build-tensor s f) s))) + (tpmake-flat (flat:build-tensor s f)))) (define tp-trefs (λ (tp b) @@ -162,9 +163,9 @@ instructions refering to the same gensym variable (error 'tp-trefs "An index was out of bounds")] [else - (tpromise (tcomp-trefs tp b) - `(,(length b) - . ,(cdr (tpromise-shape tp))))]))) + (tpmake-trefs tp b + `(,(length b) + . ,(cdr (tpromise-shape tp))))]))) ;; Default arguments shape-fn and expects-prealloc? need not be passed when f is ;; a function on scalars and doesn't expect a preallocated output vector as its @@ -176,33 +177,20 @@ instructions refering to the same gensym variable [expects-prealloc? #f] [signature (format "~a" f)]) (λ (tp) - (cond - [(scalar? tp) (f tp)] - [(and (tpromise? tp) - (null? (tpromise-shape tp))) - (tpromise - (tcomp-ext1-ρ-scalar f signature tp) - '())] - [expects-prealloc? - (tpromise - (tcomp-ext1-ρ f signature m shape-fn tp) - (merge-shapes - (tp-shape tp) - m - (shape-fn - (min-shape m (tp-shape tp)))))] - [else - (let* ((in-shape (tpromise-shape tp)) - (base-shape (min-shape m in-shape)) - (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-ρ f base-shape out-shape))) - (tpromise - (tcomp-ext1-ρ flat-f signature m shape-fn tp) - (merge-shapes - (tp-shape tp) - m - (shape-fn - (min-shape m (tp-shape tp))))))])))) + (let* ((in-shape (tp-shape tp)) + (base-shape (min-shape m in-shape)) + (shape-fn-out (shape-fn base-shape)) + (out-shape (merge-shapes in-shape m shape-fn-out))) + (cond + [(scalar? tp) (f tp)] + [(and (tpromise? tp) + (null? (tpromise-shape tp))) + (tpmake-ext1-ρ-scalar f signature tp out-shape)] + [expects-prealloc? + (tpmake-ext1-ρ f signature m shape-fn tp out-shape)] + [else + (let ((flat-f (functional->preallocated-1-ρ f base-shape shape-fn-out))) + (tpmake-ext1-ρ flat-f signature m shape-fn tp out-shape))]))))) ;; See comment for tp-ext1-ρ (define tp-ext2-ρ @@ -211,23 +199,25 @@ instructions refering to the same gensym variable [expects-prealloc? #f] [signature (format "~a" f)]) (λ (tp-t tp-u) + ;; TODO: Refactor out the code to compute the shape like in tp-ext1-ρ (cond ((and (number? tp-t) (number? tp-u)) (f tp-t tp-u)) [(and (tpromise? tp-t) (tpromise? tp-u) (null? (tpromise-shape tp-t)) (null? (tpromise-shape tp-u))) - (tpromise (tcomp-ext2-ρ-scalar f signature tp-t tp-u) '())] + ;; TODO: Fix the shape being used by using ext2-shapes + (tpmake-ext2-ρ-scalar f signature tp-t tp-u '())] [expects-prealloc? (let* ((s0 (tp-shape tp-t)) (s1 (tp-shape tp-u)) (sf0 (min-shape m s0)) (sf1 (min-shape n s1)) (sf-out (shape-fn sf0 sf1))) - (tpromise - (tcomp-ext2-ρ (ensure-tpromise tp-t) - (ensure-tpromise tp-u) - f signature m n shape-fn) + (tpmake-ext2-ρ + (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + f signature m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))] [else @@ -244,10 +234,10 @@ instructions refering to the same gensym variable t-shape u-shape out-shape))) - (tpromise - (tcomp-ext2-ρ (ensure-tpromise tp-t) - (ensure-tpromise tp-u) - flat-f signature m n shape-fn) + (tpmake-ext2-ρ + (ensure-tpromise tp-t) + (ensure-tpromise tp-u) + flat-f signature m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) @@ -261,20 +251,26 @@ instructions refering to the same gensym variable [expects-prealloc? #f] [signature (format "~a" f)]) (λ (tp zp) + ;; we invoke ensure-tpromise on just zp because it's the result of calling + ;; force*1 which forces zp to be a non-tpromise value. We can ensure tp to + ;; be a tpromise as well, but currently in our workflow we never force tp + ;; before passing it to this function. + ;; (cond ((number? tp) (f tp zp)) (expects-prealloc? - (tpromise - (tcomp-ext1-∇ tp (ensure-tpromise zp) f signature m shape-fn) - (tp-shape tp))) + (tpmake-ext1-∇ (ensure-tpromise tp) + (ensure-tpromise zp) + f signature m shape-fn + (tp-shape tp))) (else (let* ((in-shape (tpromise-shape tp)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (tpromise - (tcomp-ext1-∇ tp (ensure-tpromise zp) flat-f signature m shape-fn) - (tp-shape tp)))))))) + (tpmake-ext1-∇ (ensure-tpromise tp) + (ensure-tpromise zp) flat-f signature m shape-fn + (tp-shape tp)))))))) ;; See comment for tp-ext1-ρ (define tp-ext2-∇ @@ -303,21 +299,21 @@ instructions refering to the same gensym variable (ensure-tpromise tp-u) (ensure-tpromise tp-z)))]))))) - (define tp-d-ext2^ (λ (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) - (let* ((out0 'uncalculated) - (out1 'uncalculated)) + (let* ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) + (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) (values - (tpromise (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 0) - (tp-shape tp-t0)) - (tpromise (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 1) - (tp-shape tp-t1)))))) + (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 0 (tp-shape tp-t0)) + (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 1 (tp-shape tp-t1)))))) (define ensure-tpromise (λ (v) (cond - ((scalar? v) (tpromise (ensure-flat v) '())) + ((scalar? v) (tpmake-flat (ensure-flat v))) + ((flat? v) (tpmake-flat v)) (else v)))) (define tp-rank @@ -328,7 +324,7 @@ instructions refering to the same gensym variable (λ (s tp) (cond ((= (flat:size-of s) (flat:size-of (tpromise-shape tp))) - (tpromise (tcomp-reshape s tp) s)) + (tpmake-reshape tp s)) (else (error 'shape-error "Cannot reshape ~a to ~a~%" (tpromise-shape tp) s))))) (define tensor? diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt index aad9591..df6575f 100644 --- a/lazy/tensors/1-reflect.rkt +++ b/lazy/tensors/1-reflect.rkt @@ -12,14 +12,24 @@ (define ↓ (lambda (tp) (match tp - [(tpromise v _) - #:when (or (flat:flat? v) (number? v)) + [(tpromise v _ _ _) + #:when (number? v) v] - [(tpromise t _) + [(? tpromise-flat?) + (car (unbox (tpromise-dst tp)))] + [(tpromise t _ _ _) #:when (tcomp? t) - (let-values (((instrs data-segment) (compile-tensor t))) + (let-values (((instrs data-segment) (compile-tensor tp))) (let ((res (interp-racket instrs data-segment))) - (set-tpromise-tensor! tp res) + (cond + ((flat? res) + (set-tpromise-tensor! tp (tcomp-ds-ref #f)) + (set-box! (tpromise-dst tp) (list res)) + (set-box! (tpromise-sign tp) (list #"dsr"))) + ((number? res) + (set-tpromise-tensor! tp res) + (set-box! (tpromise-dst tp) (list)) + (set-box! (tpromise-sign tp) (list #"s" (number->bytes res))))) res))] ;; NOTE: This case runs when we use tp-scalarize to turn ;; the tensor to a scalar diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt index f247b89..450de2f 100644 --- a/lazy/tensors/B-test-programs.rkt +++ b/lazy/tensors/B-test-programs.rkt @@ -137,9 +137,10 @@ (eval-res-1 (flat:tensor 5 6 7 8))) 'tcomp-nested-list->tensor (test-program-data (λ () - (list->tensor `(,(get-test-program 'tensor-r1-0) - ,(get-test-program 'tensor-r1-0) - ,(get-test-program 'tensor-r1-0)))) + (list->tensor + `(,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0) + ,(get-test-program 'tensor-r1-0)))) (eval-res-1 (flat:tensor (flat:tensor 1 2 3) (flat:tensor 1 2 3) @@ -176,6 +177,10 @@ (λ () (id-ρ (sum (tensor 4 5 6)))) (eval-res-1 15)) + 'abs-scalar (test-program-data + (λ () + (abs-ρ (tref (tensor 4 -5 6) 1))) + (eval-res-1 5)) 'sqr (test-program-data (λ () (*-ρ (get-test-program 'build-tensor-r3-0) @@ -344,7 +349,7 @@ (define sum (ext1-ρ sum-f 1 (λ (s) '()) #t)) (define id-f (lambda (v) v)) -(define id-ρ (ext1-ρ id-f 1 (λ (s) s))) +(define id-ρ (ext1-ρ id-f 0 (λ (s) s))) (define *-ρ (ext2-ρ * 0 0)) diff --git a/lazy/tensors/c0-ast.rkt b/lazy/tensors/c0-ast.rkt index b566691..e959ccd 100644 --- a/lazy/tensors/c0-ast.rkt +++ b/lazy/tensors/c0-ast.rkt @@ -1,5 +1,7 @@ #lang racket (require "../../flat-tensors/ext-impl.rkt") +(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require file/xxhash32) ;; tensor computations (struct tcomp () #:transparent) @@ -30,21 +32,24 @@ (struct tcomp-ext2-ρ-scalar tcomp (f sign tp-t tp-u) #:transparent) (struct tcomp-ext2-ρ tcomp (tp-t tp-u f sign m n shape-fn) #:transparent) (struct tcomp-ext1-∇ tcomp (tp zp f sign m shape-fn) #:transparent) -(struct tcomp-ext2-∇ tcomp (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) +(struct tcomp-ext2-∇ tcomp (fᵈ + sign r0 r1 shape-fn + tp-t0 tp-t1 tp-z + out-ref0 out-ref1 i) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) (struct tcomp-let tcomp (lhs rhs body) #:transparent) (struct tcomp-var tcomp (name) #:transparent) (struct tcomp-ds-ref tcomp (index) #:transparent) -(struct tpromise ((tensor #:mutable) shape) +(struct tpromise ((tensor #:mutable) shape dst (sign #:mutable)) #:guard - (λ (tensor shape name) - (unless (or (flat? tensor) (tcomp? tensor)) + (λ (tensor shape data-segment-tree signature name) + (unless (or (tcomp? tensor) (number? tensor)) (error 'make-tpromise (string-append "First argument must be either a" - " tcomp or a flat tensor. Got ~a") + " number or a tcomp. Got ~a") tensor)) (unless ((listof positive-integer?) shape) (error 'make-tpromise @@ -52,9 +57,217 @@ "Second argument must be a list" " of positive integers. Got ~a") shape)) - (values tensor shape)) + (unless (and (box? data-segment-tree) + (list? (unbox data-segment-tree))) + (error 'make-tpromise + (string-append + "Third argument must be a box containing a list. Got ~a") + data-segment-tree)) + (unless (and (box? signature) (list? (unbox signature))) + (error 'make-tpromise + (string-append + "Fourth argument must be a box containing a list." + " Got ~a") + signature)) + (values tensor shape data-segment-tree signature)) #:transparent) +(define dst->data-segment + (λ (dst) + (apply vector-append (map dst-member->ds (unbox dst))))) + +(define dst-member->ds + (λ (dstm) + (cond + ((or (eqv? dstm 'uncalculated) + (number? dstm) (flat? dstm)) + (vector dstm)) + ((box? dstm) (dst->data-segment dstm)) + (else (error 'malformed-dst-member "Invalid signature. Got ~a" dstm))))) + +(define gdst-list->tensor + (λ (lst) + (box + (for/list ((l lst) + #:when (tpromise? l)) + (tpromise-dst l))))) + +(define gdst-tref + (λ (tp i) + (box (list (tpromise-dst tp) i)))) + +(define gdst-trefs + (λ (tp i-lst) + (box (list (tpromise-dst tp) (flat:list->tensor i-lst))))) + +(define gdst-ext2-∇ + (λ (tp-t0 tp-t1 tp-z) + (let ((dsn0 (tpromise-dst tp-t0)) + (dsn1 (tpromise-dst tp-t1)) + (dsnz (tpromise-dst tp-z))) + (box (list dsn0 dsn1 dsnz 'uncalculated))))) + +(define sign + (λ (ss) + (let ((xxh32-ctx (make-xxh32))) + (xxh32-reset! xxh32-ctx 0) + (sign-traverse-list! (unbox ss) xxh32-ctx) + (xxh32-digest xxh32-ctx)))) + +(define sign-traverse-list! + (λ (ss ctx) + (for ((s ss)) + (sign-traverse-member! s ctx)))) + +(define sign-traverse-member! + (λ (s ctx) + (cond + ((bytes? s) (xxh32-update! ctx (bytes-append #"_" s))) + ((box? s) (sign-traverse-list! (unbox s) ctx)) + (else (error 'malformed-sign-member "Invalid signature. Got ~a" s))))) + +(define number->bytes + (λ (n) + (string->bytes/utf-8 (number->string n)))) + +(define string->bytes string->bytes/utf-8) + +(define gs-list->tensor + (λ (lst) + (box (list* #"l>t" (map (λ (l) + (cond + ((tpromise? l) (tpromise-sign l)) + ((number? l) (bytes-append #"s_" (number->bytes l))) + (else (error 'gs-list->tensor "Unexpected: ~a" l)))) + lst))))) + +(define gs-tref + (λ (tp) + (box (list #"tr" (tpromise-sign tp) #"dsr")))) + +(define gs-trefs + (λ (tp) + (box (list #"trs" (tpromise-sign tp) #"dsr")))) + +(define gs-ext1-ρ-scalar + (λ (signature tp) + (box (list #"e1rs" (string->bytes signature) (tpromise-sign tp))))) + +(define gs-ext1-ρ + (λ (signature m tp) + (box (list #"e1r" (string->bytes signature) + (number->bytes m) (tpromise-sign tp))))) + +(define gs-ext2-ρ-scalar + (λ (signature tp-t tp-u) + (box (list #"e2rs" (string->bytes signature) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-ext2-ρ + (λ (signature m n tp-t tp-u) + (box (list #"e2r" (string->bytes signature) (number->bytes m) (number->bytes n) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-ext1-∇ + (λ (signature m tp zp) + (box (list #"e1n" (string->bytes signature) (number->bytes m) + (tpromise-sign tp) (tpromise-sign zp))))) + +(define gs-ext2-∇ + (λ (signature r0 r1 tp-t0 tp-t1 tp-z i) + (box (list #"e2n" (string->bytes signature) + (number->bytes r0) (number->bytes r1) + (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) + #"dsr" (number->bytes i))))) + +(define gs-reshape + (λ (shape tp) + (box (list* #"r" (tpromise-sign tp) (map number->bytes shape))))) + +(define tpromise-flat? + (λ (v) + (and (tpromise? v) + (tcomp-ds-ref? (tpromise-tensor v)) + (flat? (car (unbox (tpromise-dst v))))))) + +(define tpmake-flat + (λ (ft) + (tpromise (tcomp-ds-ref #f) (flat-shape ft) + (box (list ft)) (box (list #"dsr"))))) + +(define tpmake-list->tensor + (λ (lst shape) + (let ((tcomp-node (tcomp-list->tensor lst))) + (tpromise tcomp-node shape + (gdst-list->tensor lst) + (gs-list->tensor lst))))) + +;; TODO: Call ensure-promise on the tpromise argument for all tpmake function +(define tpmake-tref + (λ (tp i shape) + (tpromise (tcomp-tref tp (tcomp-ds-ref #f)) + shape + (gdst-tref tp i) + (gs-tref tp)))) + +(define tpmake-trefs + (λ (tp b shape) + (tpromise (tcomp-trefs tp (tcomp-ds-ref #f)) shape + (gdst-trefs tp b) + (gs-trefs tp)))) + +(define tpmake-ext1-ρ-scalar + (λ (f signature tp shape) + (tpromise (tcomp-ext1-ρ-scalar f signature tp) shape + (box (list (tpromise-dst tp))) + (gs-ext1-ρ-scalar signature tp)))) + +(define tpmake-ext1-ρ + (λ (f signature m shape-fn tp shape) + (tpromise (tcomp-ext1-ρ f signature m shape-fn tp) + shape + (box (list (tpromise-dst tp))) + (gs-ext1-ρ signature m tp)))) + +(define tpmake-ext2-ρ-scalar + (λ (f signature tp-t tp-u shape) + (tpromise (tcomp-ext2-ρ-scalar f signature tp-t tp-u) + shape + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-ext2-ρ-scalar signature tp-t tp-u)))) + +(define tpmake-ext2-ρ + (λ (tp-t tp-u f signature m n shape-fn shape) + (tpromise + (tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) + shape + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-ext2-ρ signature m n tp-t tp-u)))) + +(define tpmake-ext1-∇ + (λ (tp zp f signature m shape-fn shape) + (tpromise + (tcomp-ext1-∇ tp zp f signature m shape-fn) + shape + (box (list (tpromise-dst tp) (tpromise-dst zp))) + (gs-ext1-∇ signature m tp zp)))) + +(define tpmake-ext2-∇ + (λ (fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) + (tpromise + (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + shape + (gdst-ext2-∇ tp-t0 tp-t1 tp-z) + (gs-ext2-∇ signature r0 r1 tp-t0 tp-t1 tp-z i)))) + +(define tpmake-reshape + (λ (tp shape) + (tpromise + (tcomp-reshape shape tp) shape + (tpromise-dst tp) + (gs-reshape shape tp)))) + (provide (struct-out tcomp) (struct-out tcomp-list->tensor) (struct-out tcomp-tref) @@ -69,4 +282,19 @@ (struct-out tcomp-let) (struct-out tcomp-var) (struct-out tcomp-ds-ref) - (struct-out tpromise)) + (struct-out tpromise) + dst->data-segment + sign + number->bytes + tpromise-flat? + tpmake-flat + tpmake-list->tensor + tpmake-tref + tpmake-trefs + tpmake-ext1-ρ-scalar + tpmake-ext1-ρ + tpmake-ext2-ρ-scalar + tpmake-ext2-ρ + tpmake-ext1-∇ + tpmake-ext2-∇ + tpmake-reshape) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 68e9333..a71cec4 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -4,9 +4,11 @@ (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (struct ext2-∇-result (res) #:mutable #:transparent) +;; TODO: ds-ref is not being used, so remove it +(struct ds-deref (idx) #:transparent) (define ext2-∇-forcer - (λ (fᵈ r0 r1 shape-fn t0 t1 z out0 out1) + (λ (fᵈ r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) (let* ((f0 (ensure-flat t0)) (f1 (ensure-flat t1)) (fz (ensure-flat z)) @@ -55,8 +57,10 @@ vz (+ offz iz) stride-z))) - (data-segment-set! out0 (scalarize (flat s0 g0 0))) - (data-segment-set! out1 (scalarize (flat s1 g1 0))))))))) + (when out-idx0 + (data-segment-set! out-idx0 (scalarize (flat s0 g0 0)))) + (when out-idx1 + (data-segment-set! out-idx1 (scalarize (flat s1 g1 0)))))))))) (define rt:trefs (λ (ft b) @@ -73,7 +77,10 @@ (define data-segment-ref (λ (i) - (vector-ref (data-segment) i))) + (let ((res (vector-ref (data-segment) i))) + (match res + ((ds-deref idx) (vector-ref (data-segment) idx)) + (_ res))))) (define-namespace-anchor a) (define runtime @@ -81,6 +88,7 @@ (namespace-anchor->namespace a)) (provide runtime flat? flat:build-tensor flat:list->tensor - flat:tref rt:trefs (struct-out ext2-∇-result) + flat:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! + (struct-out ds-deref) ext2-∇-forcer scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index d1f465d..82c0604 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -77,7 +77,7 @@ (define interp-tensor-expr (λ (t env ds) (match t - [(tpromise tc _) (interp-tensor-expr tc env ds)] + [(tpromise tc _ _ _) (interp-tensor-expr tc env ds)] [v #:when (or (flat? v) (pair? v) (number? v)) v] [(tcomp) (interp-tensor-tcomp t env ds)]))) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 569ed61..c2500ec 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -4,283 +4,201 @@ (require (only-in "c2-interpreter.rkt" interp-tensor interp-racket)) (require "../../flat-tensors/ext-impl.rkt") (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) -(require (only-in "c1-racket-runtime.rkt" runtime)) +(require (only-in "c1-racket-runtime.rkt" + runtime ext2-∇-result-res + set-ext2-∇-result-res!)) (require rackunit) -(require file/xxhash32) -(struct counter-data (binding-name - ref-count) +(struct counter-data (binding-name ref-count) #:transparent) ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; ;; Compiler Passes ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; -;;TODO: later eds and gs passes should not be needed because the tcomp AST nodes -;;should have a signature and dss field which will be populated at the time of -;;their instantiation. Then we just access those fields from the AST node rather -;;than computing them. Use a global data segment that has flat tensors used by -;;all tcomp nodes in our program. - -;;Extracts the data segment which is a vector that contains +;; The data segment is a vector that contains ;; * scalars (arguments to tref) ;; * flat tensors ;; * flat tensor¹ of indices that will be the arguments to trefs ;; * the symbol 'uncalculated as an initial placeholder for the output of ;; tcomp-ext2-∇ which will be later replaced by the flat tensor output -;; -(define extract-data-segment - (λ (t) - (let-values (((t^ data-segment-stack) (eds-expr t '()))) - ;; convert data segment stack to data segment array - (values t^ (list->vector (reverse data-segment-stack)))))) - -;; Checks if a member equivalent to v exists in dss using equiv? and based on -;; that returns the dss index where v was inserted and the new dss with -;; insertion as values -;; TODO: Reconsider performance impact of this function -(define insert-unless-exists - (λ (v dss equiv?) - (cond - ((member v dss equiv?) - => (λ (res/rest) - (values (length (cdr res/rest)) dss))) - (else (values (length dss) (cons v dss)))))) - -(define eds-expr - (λ (t dss) - (match t - (s #:when (number? s) - (values s dss)) - (ft - #:when (flat? ft) - (let-values (((idx dss^) (insert-unless-exists ft dss eq?))) - (values (tcomp-ds-ref idx) dss^))) - ((tpromise tc s) - (let-values (((tc^ dss^) (eds-expr tc dss))) - (cond - ((number? tc^) (values tc^ dss^)) - (else (values (tpromise tc^ s) dss^))))) - ((tcomp) (eds-tcomp t dss))))) - -(define eds-tcomp - (λ (tc dss) - (match tc - [(tcomp-list->tensor lst) - (let-values (((ts dss^) - (for/fold ((ts '()) - (dss^ dss)) - ((l lst)) - (let-values (((t dss^^) - (eds-expr l dss^))) - (values (cons t ts) dss^^))))) - (values (tcomp-list->tensor (reverse ts)) dss^))] - - [(tcomp-tref tp i) - (let-values (((t dss^) (eds-expr tp dss))) - (let-values (((idx dss^^) (insert-unless-exists i dss^ eqv?))) - (values (tcomp-tref t (tcomp-ds-ref idx)) dss^^)))] - [(tcomp-trefs tp i-list) - (let-values (((t dss^) (eds-expr tp dss))) - (let-values (((idx dss^^) - ;; Comparison by flat:tensor-equal? is okay because - ;; members of b are integers (not reals) and their - ;; equality is checked without a tolerance. - ;; TODO: Reconsider performance impact of flat:list->tensor. - ;; Maybe memoize it. - (insert-unless-exists (flat:list->tensor i-list) - dss^ - flat:tensor-equal?))) - (values (tcomp-trefs t (tcomp-ds-ref idx)) dss^^)))] - [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let-values (((t0 dss^) (eds-expr tp-t0 dss))) - (let-values (((t1 dss^^) (eds-expr tp-t1 dss^))) - (let-values (((z dss^^^) (eds-expr tp-z dss^^))) - (values (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn t0 t1 z - (length dss^^^) - (add1 (length dss^^^)) i) - (cons out1 (cons out0 dss^^^))))))] - [(tcomp-ext1-∇ tp zp f signature m shape-fn) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (let-values (((zp^ dss^^) (eds-expr zp dss^))) - (values (tcomp-ext1-∇ tp^ zp^ f signature m shape-fn) - dss^^)))] - [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) - (let-values (((t dss^) (eds-expr tp-t dss))) - (let-values (((u dss^^) (eds-expr tp-u dss^))) - (values (tcomp-ext2-ρ-scalar f signature t u) dss^^)))] - [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) - (let-values (((t dss^) (eds-expr tp-t dss))) - (let-values (((u dss^^) (eds-expr tp-u dss^))) - (values (tcomp-ext2-ρ t u f signature m n shape-fn) dss^^)))] - [(tcomp-ext1-ρ-scalar f signature tp) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (values (tcomp-ext1-ρ-scalar f signature tp^) dss^))] - [(tcomp-ext1-ρ f signature m shape-fn tp) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (values (tcomp-ext1-ρ f signature m shape-fn tp^) dss^))] - [(tcomp-reshape s tp) - (let-values (((tp^ dss^) (eds-expr tp dss))) - (values (tcomp-reshape s tp^) dss^))]))) - -(define hash-signatures? - (make-parameter #t)) - -(define sign - (let ((xxh32-ctx (make-xxh32))) - (λ ss - (cond - ((hash-signatures?) - (xxh32-reset! xxh32-ctx 0) - (xxh32-update! xxh32-ctx (bytes-join ss #"_")) - (number->bytes (xxh32-digest xxh32-ctx))) - (else (format "~a" ss)))))) - -(define number->bytes - (λ (n) - (string->bytes/utf-8 (number->string n)))) - -(define string->bytes string->bytes/utf-8) - -(define generate-signature - (λ (t) - (gs-expr t))) -(define gs-expr +(define generate-ds-refs (λ (t) - (match t - (s #:when (number? s) - (sign #"s~a" (number->bytes s))) - ((tpromise tc _) (gs-expr tc)) - ((tcomp) (gs-tcomp t))))) - -(define gs-tcomp - (λ (tc) + (let-values (((t^ ref) (gdr-tpromise t 0))) + t^))) + +(define gdr-tpromise + (λ (tp ref) + (match tp + ((tpromise tc s dss sign) + (let-values (((tc^ ref^) (gdr-tcomp tc ref))) + (values (tpromise tc^ s dss sign) ref^)))))) + +(define gdr-tcomp + (λ (tc ref) (match tc + ((? number?) (values tc ref)) [(tcomp-list->tensor lst) - (apply sign #"l>t" (map gs-expr lst))] - [(tcomp-tref tp i) - (sign #"tr" (gs-expr tp) (gs-expr i))] - [(tcomp-trefs tp b) - (sign #"trs" (gs-expr tp) (gs-expr b))] - [(tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (sign #"e2n" (string->bytes signature) - (number->bytes r0) (number->bytes r1) - (gs-expr tp-t0) (gs-expr tp-t1) (gs-expr tp-z) - (number->bytes out0) (number->bytes out1) (number->bytes i))] - [(tcomp-ext1-∇ tp zp f signature m shape-fn) - (sign #"e1n" (string->bytes signature) (number->bytes m) - (gs-expr tp) (gs-expr zp))] - [(tcomp-ext2-ρ-scalar f signature tp-t tp-u) - (sign #"e2rs" (string->bytes signature) (gs-expr tp-t) (gs-expr tp-u))] - [(tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) - (sign #"e2r" (string->bytes signature) (number->bytes m) (number->bytes m) - (gs-expr tp-t) (gs-expr tp-u))] - [(tcomp-ext1-ρ-scalar f signature tp) - (sign #"e1rs" (string->bytes signature) (gs-expr tp))] - [(tcomp-ext1-ρ f signature m shape-fn tp) - (sign #"e1r" (string->bytes signature) (number->bytes m) (gs-expr tp))] + (for/fold + ((tcs '()) + (ref^ ref) + #:result (values (tcomp-list->tensor (reverse tcs)) ref^)) + ((l lst)) + (let-values (((tc ref^^) + (cond + ((tpromise? l) (gdr-tpromise l ref^)) + ((number? l) (values l ref^)) + (else (error 'gdr-list->tensor "Unexpected: ~a" l))))) + (values (cons tc tcs) ref^^)))] + [(tcomp-tref tp (tcomp-ds-ref #f)) + (let-values (((tp^ ref^) (gdr-tpromise tp ref))) + (values (tcomp-tref tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] + [(tcomp-trefs tp (tcomp-ds-ref #f)) + (let-values (((tp^ ref^) (gdr-tpromise tp ref))) + (values (tcomp-trefs tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + out-ref0 + out-ref1 i) + (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref)) + ((tp-t1^ ref^^) (gdr-tpromise tp-t1 ref^)) + ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^))) + (cond + ((and (eqv? i 0) (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref0)))) + (set-ext2-∇-result-res! out-ref0 (tcomp-ds-ref ref^^^))) + ((and (eqv? i 1) (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref1)))) + (set-ext2-∇-result-res! out-ref1 (tcomp-ds-ref ref^^^)))) + (values (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ + out-ref0 out-ref1 i) + (add1 ref^^^)))] + [(tcomp-ext1-∇ tp zp f sign m shape-fn) + (let*-values (((tp^ ref^) (gdr-tpromise tp ref)) + ((zp^ ref^^) (gdr-tpromise zp ref^))) + (values (tcomp-ext1-∇ tp^ zp^ f sign m shape-fn) ref^^))] + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^))) + (values (tcomp-ext2-ρ-scalar f sign tp-t^ tp-u^) ref^^))] + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^))) + (values (tcomp-ext2-ρ tp-t^ tp-u^ f sign m n shape-fn) ref^^))] + [(tcomp-ext1-ρ-scalar f sign tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref))) + (values (tcomp-ext1-ρ-scalar f sign tp^) ref^))] + [(tcomp-ext1-ρ f sign m shape-fn tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref))) + (values (tcomp-ext1-ρ f sign m shape-fn tp^) ref^))] [(tcomp-reshape s tp) - (apply sign #"r" (gs-expr tp) (map number->bytes s))] - [(tcomp-ds-ref index) - (sign #"dsr" (number->bytes index))]))) + (let-values (((tp^ ref^) (gdr-tpromise tp ref))) + (values (tcomp-reshape s tp^) ref^))] + [(tcomp-ds-ref #f) (values (tcomp-ds-ref ref) (add1 ref))] + ;;need these cases for testing compiler invariant + [(tcomp-let lhs rhs body) + (let*-values (((rhs^ ref^) (gdr-tpromise rhs ref)) + ((body^ ref^^) (gdr-tpromise body ref^))) + (values (tcomp-let lhs rhs^ body^) ref^^))] + [(tcomp-var name) (values (tcomp-var name) ref)]))) ;; Count references so that the tcomp AST nodes that refer to the same memory ;; location i.e. common AST nodes get extracted by let-binding them in the ;; compiled output racket code. (define count-references (λ (t) - (let-values (((counter uid) (count-references-expr t (hasheq) 0))) + (let-values (((counter uid) (cr-tpromise t (hasheq) 0))) counter))) -(define count-references-expr +;; TODO: Try using the signature field of tpromise struct as keys instead tcomp +;; references +(define cr-tpromise (λ (t counter uid) (match t - ((tpromise tc _) - (count-references-expr tc counter uid)) - ((tcomp) (count-references-tcomp t counter uid)) - (_ (values counter uid))))) + ((tpromise tc _ _ _) + (cr-tcomp tc counter uid))))) -(define count-references-tcomp +(define cr-tcomp (λ (tc counter uid) - (match-let (((counter-data tc-binding-name tc-ref-count) - (hash-ref counter tc - (λ () - (let-values (((st _) (struct-info tc))) - (let-values (((tcomp-name _0 _1 _2 _3 _4 _5 _6) - (struct-type-info st))) - (counter-data (string->symbol - (format "~a_~a" tcomp-name uid)) - 0))))))) - (let* ((new-count (add1 tc-ref-count)) - (counter^ (hash-set counter tc - (counter-data tc-binding-name - new-count))) - (uid^ (add1 uid))) - (cond - ((> new-count 1) - ;; No need to increase reference count of children if parent occurs - ;; more than once. This helps avoid creating extra tcom-var later in - ;; ecs if child tcomp occurs only once within the parent tcomp, but - ;; the parent tcomp itself occurs more than once. - (values counter^ uid^)) - (else - (match tc - [(tcomp-list->tensor lst) - (for/fold - ((counter^^ counter^) - (uid^^ uid^)) - ((l lst)) - (count-references-expr l counter^^ uid^^))] - [(tcomp-tref tp i) - (count-references-expr tp counter^ uid^)] - [(tcomp-trefs tp b) - (count-references-expr tp counter^ uid^)] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let*-values (((counter-1 uid-1) (count-references-expr tp-t0 counter^ uid^)) - ((counter-2 uid-2) (count-references-expr tp-z counter-1 uid-1))) - (count-references-expr tp-t1 counter-2 uid-2))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) - (let-values (((counter-1 uid-1) (count-references-expr tp counter^ uid^))) - (count-references-expr zp counter-1 uid-1))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) - (let-values (((counter-1 uid-1) (count-references-expr tp-t counter^ uid^))) - (count-references-expr tp-u counter-1 uid-1))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) - (let-values (((counter-1 uid-1) (count-references-expr tp-t counter^ uid^))) - (count-references-expr tp-u counter-1 uid-1))] - [(tcomp-ext1-ρ-scalar f sign tp) - (count-references-expr tp counter^ uid^)] - [(tcomp-ext1-ρ f sign m shape-fn tp) - (count-references-expr tp counter^ uid^)] - [(tcomp-reshape s tp) - (count-references-expr tp counter^ uid^)] - [(tcomp-ds-ref index) (values counter^ uid^)] - ;;need these cases for testing compiler invariant - [(tcomp-let lhs rhs body) - (let-values (((counter-1 uid-1) (count-references-expr rhs counter^ uid^))) - (count-references-expr body counter-1 uid-1))] - [(tcomp-var name) (values counter^ uid^)]))))))) + (cond + ((number? tc) (values counter uid)) + (else + (match-let (((counter-data tc-binding-name tc-ref-count) + (hash-ref counter tc + (λ () + (let-values (((st _) (struct-info tc))) + (let-values (((tcomp-name _0 _1 _2 _3 _4 _5 _6) + (struct-type-info st))) + (counter-data (string->symbol + (format "~a_~a" tcomp-name uid)) + 0))))))) + (let* ((new-count (add1 tc-ref-count)) + (counter^ (hash-set counter tc + (counter-data tc-binding-name + new-count))) + (uid^ (add1 uid))) + (cond + ((> new-count 1) + ;; No need to increase reference count of children if parent occurs + ;; more than once. This helps avoid creating extra tcom-var later in + ;; ecs if child tcomp occurs only once within the parent tcomp, but + ;; the parent tcomp itself occurs more than once. + (values counter^ uid^)) + (else + (match tc + [(tcomp-list->tensor lst) + (for/fold + ((counter^^ counter^) + (uid^^ uid^)) + ((l lst)) + (cond + ((tpromise? l) (cr-tpromise l counter^^ uid^^)) + ((number? l) (values counter^^ uid^^)) + (else (error 'cr-list->tensor "Unexpected: ~a" l))))] + [(tcomp-tref tp i) + (cr-tpromise tp counter^ uid^)] + [(tcomp-trefs tp b) + (cr-tpromise tp counter^ uid^)] + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) + ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) + (cr-tpromise tp-t1 counter-2 uid-2))] + [(tcomp-ext1-∇ tp zp f sign m shape-fn) + (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) + (cr-tpromise zp counter-1 uid-1))] + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-ext1-ρ-scalar f sign tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-ext1-ρ f sign m shape-fn tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-reshape s tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-ds-ref index) (values counter^ uid^)] + ;;need these cases for testing compiler invariant + [(tcomp-let lhs rhs body) + (let-values (((counter-1 uid-1) (cr-tpromise rhs counter^ uid^))) + (cr-tpromise body counter-1 uid-1))] + [(tcomp-var name) (values counter^ uid^)]))))))))) (define extract-common-subexpressions (λ (t counter) (let-values (((instrs bindings) - (run-compiler-ecs (ecs-expr t counter) '()))) + (run-compiler-ecs (ecs-tpromise t counter) '()))) (for/fold ((body instrs)) ((binding bindings)) (tcomp-let (car binding) (cdr binding) body))))) -(define ecs-expr +(define ecs-tpromise (λ (tc counter) (match tc - [(tpromise tc s) + [(tpromise tc s dss sign) (->ecs - (ecs-expr tc counter) + (ecs-tcomp tc counter) (λ (instrs) - (inj-ecs-val (tpromise instrs s))))] - [tc #:when (number? tc) - (inj-ecs-val tc)] - [(tcomp) (ecs-tcomp tc counter)]))) + (inj-ecs-val (tpromise instrs s dss sign))))]))) (define ecs-tcomp (λ (tc counter) @@ -289,13 +207,18 @@ (λ () (counter-data (gensym 'illegal) 0))))) (match tc + [tc #:when (number? tc) + (inj-ecs-val tc)] [(tcomp-list->tensor lst) (let ((instrs-list-compiler (for/foldr ((list-compiler (inj-ecs-val '()))) ((arg lst)) (->ecs - (ecs-expr arg counter) + (cond + ((tpromise? arg) (ecs-tpromise arg counter)) + ((number? arg) (inj-ecs-val arg)) + (else (error 'ecs-list->tensor "Unexpected: ~a" arg))) (λ (instrs) (->ecs list-compiler @@ -307,23 +230,23 @@ (inj-ecs-tcomp (tcomp-list->tensor instrs-list) tc-counter-data))))] [(tcomp-tref tp i) (->ecs - (ecs-expr tp counter) + (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-tref instrs i) tc-counter-data)))] [(tcomp-trefs tp b) (->ecs - (ecs-expr tp counter) + (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-trefs instrs b) tc-counter-data)))] [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (->ecs - (ecs-expr tp-t0 counter) + (ecs-tpromise tp-t0 counter) (λ (t0-instrs) (->ecs - (ecs-expr tp-t1 counter) + (ecs-tpromise tp-t1 counter) (λ (t1-instrs) (->ecs - (ecs-expr tp-z counter) + (ecs-tpromise tp-z counter) (λ (z-instrs) (inj-ecs-tcomp (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn @@ -332,47 +255,47 @@ tc-counter-data)))))))] [(tcomp-ext1-∇ tp zp f sign m shape-fn) (->ecs - (ecs-expr tp counter) + (ecs-tpromise tp counter) (λ (t-instrs) (->ecs - (ecs-expr zp counter) + (ecs-tpromise zp counter) (λ (z-instrs) (inj-ecs-tcomp (tcomp-ext1-∇ t-instrs z-instrs f sign m shape-fn) tc-counter-data)))))] [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) (->ecs - (ecs-expr tp-t counter) + (ecs-tpromise tp-t counter) (λ (t-instrs) (->ecs - (ecs-expr tp-u counter) + (ecs-tpromise tp-u counter) (λ (u-instrs) (inj-ecs-tcomp (tcomp-ext2-ρ-scalar f sign t-instrs u-instrs) tc-counter-data)))))] [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) (->ecs - (ecs-expr tp-t counter) + (ecs-tpromise tp-t counter) (λ (t-instrs) (->ecs - (ecs-expr tp-u counter) + (ecs-tpromise tp-u counter) (λ (u-instrs) (inj-ecs-tcomp (tcomp-ext2-ρ t-instrs u-instrs f sign m n shape-fn) tc-counter-data)))))] [(tcomp-ext1-ρ-scalar f sign tp) (->ecs - (ecs-expr tp counter) + (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f sign instrs) tc-counter-data)))] [(tcomp-ext1-ρ f sign m shape-fn tp) (->ecs - (ecs-expr tp counter) + (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-ext1-ρ f sign m shape-fn instrs) tc-counter-data)))] [(tcomp-reshape s tp) (->ecs - (ecs-expr tp counter) + (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-reshape s instrs) tc-counter-data)))] [(tcomp-ds-ref index) (inj-ecs-tcomp tc tc-counter-data)])))) @@ -422,77 +345,83 @@ (define generate-racket (λ (t) - (gr-expr t))) + (gr-tpromise t))) -(define gr-expr +(define gr-tpromise (λ (t) (match t - [(tpromise tc _) (gr-expr tc)] - [v #:when (number? v) v] - [(tcomp) (gr-tcomp t)]))) + [(tpromise tc _ _ _) (gr-tcomp tc)]))) (define gr-tcomp (λ (tc) (match tc + [v #:when (number? v) v] [(tcomp-list->tensor lst) - (let ((instrs-list (map (λ (t) (gr-expr t)) lst))) + (let ((instrs-list (map (λ (t) + (cond + ((tpromise? t) (gr-tpromise t)) + ((number? t) t) + (else (error 'gr-list->tensor "Unexpected: ~a" t)))) + lst))) `(flat:list->tensor (list ,@instrs-list)))] [(tcomp-tref tp i) - (let ((instrs (gr-expr tp)) - (i-instrs (gr-expr i))) + (let ((instrs (gr-tpromise tp)) + (i-instrs (gr-tcomp i))) `(flat:tref ,instrs ,i-instrs))] [(tcomp-trefs tp b) - (let ((instrs (gr-expr tp)) - (b-instrs (gr-expr b))) + (let ((instrs (gr-tpromise tp)) + (b-instrs (gr-tcomp b))) `(rt:trefs ,instrs ,b-instrs))] [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - (let ((t0-instrs (gr-expr tp-t0)) - (t1-instrs (gr-expr tp-t1)) - (z-instrs (gr-expr tp-z))) - (let ((b (if (zero? i) out0 out1))) - `(let* ([b ,b] - [v (data-segment-ref b)]) + (let ((t0-instrs (gr-tpromise tp-t0)) + (t1-instrs (gr-tpromise tp-t1)) + (z-instrs (gr-tpromise tp-z)) + (out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (let ((index (if (zero? i) out-idx0 out-idx1))) + `(let* ([index ,index] + [v (data-segment-ref index)]) (cond ((eqv? v 'uncalculated) (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn ,t0-instrs ,t1-instrs - ,z-instrs ,out0 ,out1) - (data-segment-ref b)) + ,z-instrs ,out-idx0 ,out-idx1) + (data-segment-ref index)) (else v)))))] [(tcomp-ext1-∇ tp zp f sign m shape-fn) - (let ((t-instrs (gr-expr tp)) - (z-instrs (gr-expr zp))) + (let ((t-instrs (gr-tpromise tp)) + (z-instrs (gr-tpromise zp))) `(scalarize (flat-ext1-∇ ,f ,m ,shape-fn (ensure-flat ,t-instrs) (ensure-flat ,z-instrs))))] [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) - (let ((t-instrs (gr-expr tp-t)) - (u-instrs (gr-expr tp-u))) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) `(,f ,t-instrs ,u-instrs))] [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) - (let ((t-instrs (gr-expr tp-t)) - (u-instrs (gr-expr tp-u))) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) `(scalarize (flat-ext2-ρ ,f ,m ,n ,shape-fn (ensure-flat ,t-instrs) (ensure-flat ,u-instrs))))] [(tcomp-ext1-ρ-scalar f sign tp) - (let ((instrs (gr-expr tp))) + (let ((instrs (gr-tpromise tp))) `(,f ,instrs))] [(tcomp-ext1-ρ f sign m shape-fn tp) - (let ((instrs (gr-expr tp))) + (let ((instrs (gr-tpromise tp))) `(scalarize (flat-ext1-ρ ,f ,m ,shape-fn (ensure-flat ,instrs))))] [(tcomp-reshape s tp) - (let ((instrs (gr-expr tp))) + (let ((instrs (gr-tpromise tp))) `(flat ',s (flat-store ,instrs) (flat-offset ,instrs)))] [(tcomp-let lhs rhs body) - (let ((rhs-instrs (gr-expr rhs)) - (body-instrs (gr-expr body))) + (let ((rhs-instrs (gr-tpromise rhs)) + (body-instrs (gr-tpromise body))) `(let ((,lhs ,rhs-instrs)) ,body-instrs))] [(tcomp-var name) name] @@ -516,26 +445,27 @@ (define compile-tensor (λ (t) (display-compiler-trace 'Source-Tensor t) - (let-values (((eds-instrs ds) (extract-data-segment t))) - (display-compiler-trace 'Extract-Data-Segment-data ds) - (display-compiler-trace 'Extract-Data-Segment-instructions eds-instrs) - (let ((signature (generate-signature eds-instrs))) - (display-compiler-trace 'Generate-Signature signature) + (let ((ds (dst->data-segment (tpromise-dst t)))) + (display-compiler-trace 'Data-Segment ds) + (let ((signature (sign (tpromise-sign t)))) + (display-compiler-trace 'Signature signature) (cond ((hash-has-key? (cache) signature) (let ((compiled (hash-ref (cache) signature))) (display-compiler-trace 'Cache-Hit signature) (values compiled ds))) (else - (let ((counter (count-references eds-instrs))) - (display-compiler-trace 'Count-References counter) - (let ((extracted (extract-common-subexpressions eds-instrs counter))) - (display-compiler-trace 'Extract-Common-Subexpressions extracted) - (let* ((gr (generate-racket extracted)) - (rkt (compile-racket gr))) - (display-compiler-trace 'Generate-Racket gr) - (hash-set! (cache) signature rkt) - (values rkt ds)))))))))) + (let ((instrs-dsr (generate-ds-refs t))) + (display-compiler-trace 'Generate-DS-Refs instrs-dsr) + (let ((counter (count-references instrs-dsr))) + (display-compiler-trace 'Count-References counter) + (let ((extracted (extract-common-subexpressions instrs-dsr counter))) + (display-compiler-trace 'Extract-Common-Subexpressions extracted) + (let* ((gr (generate-racket extracted)) + (rkt (compile-racket gr))) + (display-compiler-trace 'Generate-Racket gr) + (hash-set! (cache) signature rkt) + (values rkt ds))))))))))) (define compile-racket (λ (r) @@ -545,6 +475,8 @@ ;;TODO: update this for new compiler passes (define compile-tensor/checks (λ (t) + t + #; (let-values (((eds-instrs ds) (extract-data-segment t))) (flat:check-tensor-equal? (interp-tensor t) (interp-tensor eds-instrs)) (let ((counter (count-references t))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt index c5f1d78..ce4481e 100644 --- a/lazy/tensors/test/test-1-reflect.rkt +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -1,35 +1,12 @@ (module+ test (require rackunit) - (require (only-in "c3-compiler.rkt" - compile-tensor/checks)) (require "0-lazy.rkt") (require "B-test-programs.rkt") - ;; TODO: Add a comment above each test case describing what the test case is testing - (define-check (check-compiler-invariants tp) - (let-values (((instrs ds) (compile-tensor tp))) - (with-check-info - (('data-segment ds) - ('instrs instrs)) - 'ok - #; - (for ((name/flat ds)) - (unless (and (flat:flat? (cdr name/flat)) - (not (null? (flat:flat-shape (cdr name/flat))))) - (fail-check (format (string-append "Value associated with the variable" - " ~a should be a flat tensor. " - "Associated value found: ~a") - (car name/flat) (cdr name/flat))))) - #; - (define unique-flats (list->seteq (map cdr ds))) - #; - (unless (equal? (set-count unique-flats) - (length (filter flat? (map cdr ds)))) - (fail-check (string-append "Duplicate flat tensors found" - " in data segment. Variables in data segment" - " should be paired with unique" - " flat tensors")))))) - ;;TODO: Move all check-compiler-invariant checks to the test file for - ;;c3-compiler.rkt file. + + (define evaluated-tpromise? + (λ (tp) + (or (tpromise-flat? tp) + (number? (tpromise-tensor tp))))) (for (((test-name test-data) (in-hash test-programs))) (match-define (test-program-data th res) test-data) @@ -41,7 +18,7 @@ forced res (format "Expected result doesn't match in test case ~a" test-name)) - (check-false (tcomp? (tpromise-tensor tp))) + (check-pred evaluated-tpromise? tp) (check-equal? (tpromise-shape tp) (flat:shape forced)))) ((eval-res-2 res1 res2) (let*-values (((tp1 tp2) (th)) @@ -51,18 +28,19 @@ forced1 res1 (format "Expected first result doesn't match in test case ~a" test-name)) - (check-false (tcomp? (tpromise-tensor tp1))) + (check-pred evaluated-tpromise? tp1) (check-equal? (tpromise-shape tp1) (flat:shape forced1)) (flat:check-tensor-equal? forced2 res2 (format "Expected second result doesn't match in test case ~a" test-name)) - (check-false (tcomp? (tpromise-tensor tp2))) + (check-pred evaluated-tpromise? tp2) (check-equal? (tpromise-shape tp2) (flat:shape forced2)))))) (define test-tensor-r1-0 (get-test-program 'tensor-r1-0)) - (check-true (flat:flat? (tpromise-tensor test-tensor-r1-0))) + (check-false (flat:flat? (tpromise-tensor test-tensor-r1-0))) + (check-true (flat:flat? (car (unbox (tpromise-dst test-tensor-r1-0))))) (check-exn exn:fail? (λ () (tensor test-tensor-r1-0 4))) (check-exn exn:fail? (λ () (tensor 4 test-tensor-r1-0))) @@ -97,6 +75,31 @@ (flat:check-tensor-equal? (↓ test-tcomp-partial-eval) (↓ (tensor 1 2 3))) + (define test-id-scalar (get-test-program 'id-scalar)) + (define test-force-scalar + (+-ρ test-id-scalar + (get-test-program 'sum-nested))) + (void (↓ test-id-scalar)) + (flat:check-tensor-equal? (↓ test-force-scalar) + (↓ (tensor 19 21 20))) + + (define test-force-subexpr + (+-ρ (get-test-program 'id-scalar) + (get-test-program 'sum-nested))) + (define test-force-mutate + (+-ρ test-force-subexpr + (+-ρ (get-test-program 'sum-nested) + (get-test-program 'sum-nested)))) + (void (↓ test-force-subexpr)) + (flat:check-tensor-equal? (↓ test-force-mutate) + (↓ (tensor 27 33 30))) + + (define test-tp-r1 (tensor -1 -2 -3)) + (define test-force-supexpr (abs-ρ test-tp-r1)) + (void (↓ test-force-supexpr)) + (flat:check-tensor-equal? (↓ test-tp-r1) + (↓ (tensor -1 -2 -3))) + (define test-trefs (get-test-program 'tcomp-trefs)) (check-true (tcomp? (tpromise-tensor test-trefs))) (check-exn exn:fail? (λ () (trefs test-nested-tensor '(0 4)))) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index 76c49cb..96ab6f1 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -3,35 +3,57 @@ (require "B-test-programs.rkt") (require "0-lazy.rkt") + (define current-test-program-name (make-parameter #f)) + (define-check (check-compiler-invariants tp) + (let-values (((instrs ds) (compile-tensor tp))) + (with-check-info + (('data-segment ds) + ('instrs instrs)) + (define test-name-string + (cond + ((current-test-program-name) (format "In test case: ~a" + (current-test-program-name))) + (else ""))) + 'ok + ;;TODO: Add a check to ensure the number of tcomp-ds-ref occurring in + ;;the input equals the size of the data segment + #; + (for ((name/flat ds)) + (unless (and (flat:flat? (cdr name/flat)) + (not (null? (flat:flat-shape (cdr name/flat))))) + (fail-check (format (string-append "Value associated with the variable" + " ~a should be a flat tensor. " + "Associated value found: ~a") + (car name/flat) (cdr name/flat)))))))) + + (for (((test-name test-data) (in-hash test-programs))) + (match-define (test-program-data th res) test-data) + (parameterize ((current-test-program-name test-name)) + (match res + ((eval-res-1 res) + (let* ((tp (th))) + (check-compiler-invariants tp))) + ((eval-res-2 res1 res2) + (let*-values (((tp1 tp2) (th))) + (check-compiler-invariants tp1) + (check-compiler-invariants tp2)))))) (define-check (check-signatures-equal? t1 t2) - (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) - ((eds-instrs-2 ds2) (extract-data-segment t2))) - (let ((sig1 (generate-signature eds-instrs-1)) - (sig2 (generate-signature eds-instrs-2))) - (with-check-info - (('extracted-instrs-1 eds-instrs-1) - ('extracted-instrs-2 eds-instrs-2) - ('data-segment-1 ds1) - ('data-segment-2 ds2) - ('signature-1 sig1) + (let ((sig1 (tpromise-sign t1)) + (sig2 (tpromise-sign t2))) + (with-check-info + (('signature-1 sig1) ('signature-2 sig2)) - (unless (equal? sig1 sig2) - (fail-check "signature mismatch")))))) + (unless (equal? sig1 sig2) + (fail-check "signature mismatch"))))) (define-check (check-signatures-not-equal? t1 t2) - (let-values (((eds-instrs-1 ds1) (extract-data-segment t1)) - ((eds-instrs-2 ds2) (extract-data-segment t2))) - (let ((sig1 (generate-signature eds-instrs-1)) - (sig2 (generate-signature eds-instrs-2))) - (with-check-info - (('extracted-instrs-1 eds-instrs-1) - ('extracted-instrs-2 eds-instrs-2) - ('data-segment-1 ds1) - ('data-segment-2 ds2) - ('signature-1 sig1) + (let ((sig1 (tpromise-sign t1)) + (sig2 (tpromise-sign t2))) + (with-check-info + (('signature-1 sig1) ('signature-2 sig2)) - (when (equal? sig1 sig2) - (fail-check "signatures musn't match")))))) + (when (equal? sig1 sig2) + (fail-check "signatures musn't match"))))) (define test-tensor-r1-1 (get-test-program 'tensor-r1-1)) (define test-tcomp-tref (get-test-program 'tcomp-tref)) @@ -42,7 +64,6 @@ (define tensor-r1 (get-test-program 'tensor-r1-0)) (check-signatures-equal? (*-ρ 2 tensor-r1) (*-ρ 3 tensor-r1)) - (check-signatures-not-equal? (*-ρ 2 3) (*-ρ 3 3)) (define v^ (random-tensor (list 10 4))) (define r^ (random-tensor (list 10 4 2))) @@ -75,17 +96,17 @@ (let-values (((rkt ds) (compile-tensor (get-test-program 'extract-ds-once-tref)))) (check-pred (λ (ds) - (eqv? (vector-length ds) 2)) + (eqv? (set-count (list->seteq (vector->list ds))) 2)) ds - (string-append "Tensors and tref indices occurring multiple times in" - " source AST but referring to the same tensor AST node must" - " be added to the data segment only once."))) + (string-append "eq? equivalent flat tensors and tref indices" + " used to construct the source AST must" + " be eq? equivalent in the data segment as well."))) (let-values (((rkt ds) (compile-tensor (get-test-program 'extract-ds-once-trefs)))) (check-pred (λ (ds) - (eqv? (vector-length ds) 2)) + (eqv? (set-count (list->seteq (vector->list ds))) 2)) ds - (string-append "Tensors and trefs index lists occurring multiple times in" - " source AST but pointing to the same tensor AST node must" - " be added to the data segment only once."))) + (string-append "eq? equivalent flat tensors and trefs index lists" + " used to construct the source AST must" + " be eq? equivalent in the data segment as well."))) ) From 7c67bc777a7100c4be6c74d473c72a8c8acf4760 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 23 Dec 2023 00:01:04 -0500 Subject: [PATCH 73/83] [add-lazy]Remove ds-ref because it wasn't used --- lazy/tensors/c1-racket-runtime.rkt | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index a71cec4..a882e68 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -4,8 +4,6 @@ (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (struct ext2-∇-result (res) #:mutable #:transparent) -;; TODO: ds-ref is not being used, so remove it -(struct ds-deref (idx) #:transparent) (define ext2-∇-forcer (λ (fᵈ r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) @@ -77,10 +75,7 @@ (define data-segment-ref (λ (i) - (let ((res (vector-ref (data-segment) i))) - (match res - ((ds-deref idx) (vector-ref (data-segment) idx)) - (_ res))))) + (vector-ref (data-segment) i))) (define-namespace-anchor a) (define runtime @@ -89,6 +84,5 @@ (provide runtime flat? flat:build-tensor flat:list->tensor flat:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! - (struct-out ds-deref) ext2-∇-forcer scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment) From be328f2f50b31475f704b26a09af62d3b3282506 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 23 Dec 2023 00:06:33 -0500 Subject: [PATCH 74/83] =?UTF-8?q?[add-lazy]Remove=20redundant=20=E2=86=93?= =?UTF-8?q?=20in=20the=20definition=20of=20=CF=81-=E2=88=87-checker?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lazy/autodiff/D-test-helpers.rkt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index caac553..dde0e59 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -7,9 +7,8 @@ (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) - ;; TODO: This code ahould work even after removing the ↓ call - (let* ((y (↓ (apply fn args))) - (g (↓ (apply (∇¹ fn) args))) + (let* ((y (apply fn args)) + (g (apply (∇¹ fn) args)) (ans-ρ (ρ ans))) (cond ((and (equal-wt? ans-ρ (ρ y)) From 5c2d51f8170882d0bc1f40a4a79d9687e9ac0782 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Wed, 27 Dec 2023 22:19:14 -0500 Subject: [PATCH 75/83] [add-lazy]Move ensure-tpromise calls in tpmake-* functions --- lazy/tensors/0-lazy.rkt | 36 +++++----------------------- lazy/tensors/c0-ast.rkt | 53 +++++++++++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 48 deletions(-) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index b0afa55..5fbcf2f 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -215,8 +215,7 @@ instructions refering to the same gensym variable (sf1 (min-shape n s1)) (sf-out (shape-fn sf0 sf1))) (tpmake-ext2-ρ - (ensure-tpromise tp-t) - (ensure-tpromise tp-u) + tp-t tp-u f signature m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))] @@ -235,8 +234,7 @@ instructions refering to the same gensym variable u-shape out-shape))) (tpmake-ext2-ρ - (ensure-tpromise tp-t) - (ensure-tpromise tp-u) + tp-t tp-u flat-f signature m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))) @@ -251,26 +249,17 @@ instructions refering to the same gensym variable [expects-prealloc? #f] [signature (format "~a" f)]) (λ (tp zp) - ;; we invoke ensure-tpromise on just zp because it's the result of calling - ;; force*1 which forces zp to be a non-tpromise value. We can ensure tp to - ;; be a tpromise as well, but currently in our workflow we never force tp - ;; before passing it to this function. ;; (cond ((number? tp) (f tp zp)) (expects-prealloc? - (tpmake-ext1-∇ (ensure-tpromise tp) - (ensure-tpromise zp) - f signature m shape-fn - (tp-shape tp))) + (tpmake-ext1-∇ tp zp f signature m shape-fn (tp-shape tp))) (else (let* ((in-shape (tpromise-shape tp)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (tpmake-ext1-∇ (ensure-tpromise tp) - (ensure-tpromise zp) flat-f signature m shape-fn - (tp-shape tp)))))))) + (tpmake-ext1-∇ tp zp flat-f signature m shape-fn (tp-shape tp)))))))) ;; See comment for tp-ext1-ρ (define tp-ext2-∇ @@ -285,19 +274,13 @@ instructions refering to the same gensym variable (λ (tp-t tp-u tp-z) (cond (expects-prealloc? - (tp-f f - (ensure-tpromise tp-t) - (ensure-tpromise tp-u) - (ensure-tpromise tp-z))) + (tp-f f tp-t tp-u tp-z)) [else (let* ((t-shape (min-shape m (tp-shape tp-t))) (u-shape (min-shape n (tp-shape tp-u))) (out-shape (shape-fn t-shape u-shape)) (flat-f (functional->preallocated-2-∇ f t-shape u-shape out-shape))) - (tp-f flat-f - (ensure-tpromise tp-t) - (ensure-tpromise tp-u) - (ensure-tpromise tp-z)))]))))) + (tp-f flat-f tp-t tp-u tp-z))]))))) (define tp-d-ext2^ (λ (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) @@ -309,13 +292,6 @@ instructions refering to the same gensym variable (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 1 (tp-shape tp-t1)))))) -(define ensure-tpromise - (λ (v) - (cond - ((scalar? v) (tpmake-flat (ensure-flat v))) - ((flat? v) (tpmake-flat v)) - (else v)))) - (define tp-rank (λ (tp) (flat:len (tp-shape tp)))) diff --git a/lazy/tensors/c0-ast.rkt b/lazy/tensors/c0-ast.rkt index e959ccd..70f2ae0 100644 --- a/lazy/tensors/c0-ast.rkt +++ b/lazy/tensors/c0-ast.rkt @@ -202,7 +202,6 @@ (gdst-list->tensor lst) (gs-list->tensor lst))))) -;; TODO: Call ensure-promise on the tpromise argument for all tpmake function (define tpmake-tref (λ (tp i shape) (tpromise (tcomp-tref tp (tcomp-ds-ref #f)) @@ -236,30 +235,48 @@ (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) (gs-ext2-ρ-scalar signature tp-t tp-u)))) +(define ensure-tpromise + (λ (v) + (cond + ((number? v) (tpmake-flat (ensure-flat v))) + ((flat? v) (tpmake-flat v)) + (else v)))) + (define tpmake-ext2-ρ (λ (tp-t tp-u f signature m n shape-fn shape) - (tpromise - (tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) - shape - (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) - (gs-ext2-ρ signature m n tp-t tp-u)))) - + (let ((tp-t (ensure-tpromise tp-t)) + (tp-u (ensure-tpromise tp-u))) + (tpromise + (tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) + shape + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-ext2-ρ signature m n tp-t tp-u))))) + +;; we invoke ensure-tpromise on just zp because it's the result of calling +;; force*1 which forces zp to be a non-tpromise value. We can ensure tp to +;; be a tpromise as well, but currently in our workflow we never force tp +;; before passing it to this function, nor do we need scalar tp to be wrapped in +;; a tpromise. (define tpmake-ext1-∇ (λ (tp zp f signature m shape-fn shape) - (tpromise - (tcomp-ext1-∇ tp zp f signature m shape-fn) - shape - (box (list (tpromise-dst tp) (tpromise-dst zp))) - (gs-ext1-∇ signature m tp zp)))) + (let ((zp (ensure-tpromise zp))) + (tpromise + (tcomp-ext1-∇ tp zp f signature m shape-fn) + shape + (box (list (tpromise-dst tp) (tpromise-dst zp))) + (gs-ext1-∇ signature m tp zp))))) (define tpmake-ext2-∇ (λ (fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) - (tpromise - (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn - tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) - shape - (gdst-ext2-∇ tp-t0 tp-t1 tp-z) - (gs-ext2-∇ signature r0 r1 tp-t0 tp-t1 tp-z i)))) + (let ((tp-t0 (ensure-tpromise tp-t0)) + (tp-t1 (ensure-tpromise tp-t1)) + (tp-z (ensure-tpromise tp-z))) + (tpromise + (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + shape + (gdst-ext2-∇ tp-t0 tp-t1 tp-z) + (gs-ext2-∇ signature r0 r1 tp-t0 tp-t1 tp-z i))))) (define tpmake-reshape (λ (tp shape) From fa4ab28c1a78cf3040e2d0fd1a8f0be84328bec4 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Mon, 15 Jul 2024 21:43:44 -0400 Subject: [PATCH 76/83] [add-lazy]Tidy up TODOs --- lazy/autodiff/A-autodiff.rkt | 2 +- lazy/autodiff/B-prims.rkt | 5 ----- lazy/tensors/c1-racket-runtime.rkt | 1 - malted/A-core.rkt | 1 - malted/D-gradient-descent.rkt | 20 +++++++++++++++++++- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index e7399a1..c23a9ba 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -73,7 +73,7 @@ (λ (y wrt) (let ((σ (∇σ y (hasheq)))) (map* (λ (d) - (↓ (hash-ref σ d 0.0))) + (hash-ref σ d 0.0)) wrt)))) (define ∇σ diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 07c1fb4..44b7c7d 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -7,11 +7,6 @@ #:property prop:procedure (λ (this . args) (apply (prim-proc this) args))) -;;TODO: move expects-preallocated?, functional->preallocated-1-ρ, -;;functional->preallocated-1-∇, functional->preallocated-2-ρ, -;;functional->preallocated-2-∇ here because they depend on the representation of -;;prims - (define prim1 (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) (let ((prim-sign (symbol->string (gensym 'prim1)))) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index a882e68..5443ef3 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -79,7 +79,6 @@ (define-namespace-anchor a) (define runtime - ;;TODO explicitly declare the names being included in this namespace (namespace-anchor->namespace a)) (provide runtime flat? flat:build-tensor flat:list->tensor diff --git a/malted/A-core.rkt b/malted/A-core.rkt index 4d50c79..4139f6b 100644 --- a/malted/A-core.rkt +++ b/malted/A-core.rkt @@ -1,7 +1,6 @@ #lang racket (require "../base.rkt") -;; TODO: This is not implementation independent. Figure out a fix (require (only-in "../lazy/tensors.rkt" ↓)) (define dot-product diff --git a/malted/D-gradient-descent.rkt b/malted/D-gradient-descent.rkt index b89bef4..259fe3d 100644 --- a/malted/D-gradient-descent.rkt +++ b/malted/D-gradient-descent.rkt @@ -12,12 +12,30 @@ (declare-hyper revs) (declare-hyper alpha) +;;TODO: abstract away the lazy implementation specific ↓ using a with-aspect + +;; For lazy implementation +#; +(define-syntax with-aspect + (syntax-rules () + [(_ 'gd-update f) + (lambda (pa g) (map* ↓ (f pa g)))] + [(_ _ f) f])) + +;; For other implementations +#; +(define-syntax with-aspect + (syntax-rules () + [(_ _ f) + f])) + (define gradient-descent (lambda (inflate deflate update) (λ (obj theta) (let ((ctr 0)) (let ((f (λ (big-theta) - (map (λ (pa g) (map* ↓ (update pa g))) + (map #;(with-aspect 'gd-update update) + (lambda (pa g) (map* ↓ (update pa g))) big-theta (gradient-of obj (map deflate big-theta)))))) From 96e78e6cef70f81681644857361ec97a4240af74 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Fri, 29 Dec 2023 16:26:30 -0500 Subject: [PATCH 77/83] =?UTF-8?q?[add-lazy]Refactor=20shape=20computation?= =?UTF-8?q?=20in=20tp-ext*-=CF=81=20functions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lazy/tensors/0-lazy.rkt | 57 ++++++++++++++---------------------- lazy/tensors/c3-compiler.rkt | 4 ++- 2 files changed, 25 insertions(+), 36 deletions(-) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 5fbcf2f..208d921 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -199,45 +199,32 @@ instructions refering to the same gensym variable [expects-prealloc? #f] [signature (format "~a" f)]) (λ (tp-t tp-u) - ;; TODO: Refactor out the code to compute the shape like in tp-ext1-ρ - (cond - ((and (number? tp-t) (number? tp-u)) - (f tp-t tp-u)) - [(and (tpromise? tp-t) (tpromise? tp-u) - (null? (tpromise-shape tp-t)) - (null? (tpromise-shape tp-u))) - ;; TODO: Fix the shape being used by using ext2-shapes - (tpmake-ext2-ρ-scalar f signature tp-t tp-u '())] - [expects-prealloc? - (let* ((s0 (tp-shape tp-t)) - (s1 (tp-shape tp-u)) - (sf0 (min-shape m s0)) - (sf1 (min-shape n s1)) - (sf-out (shape-fn sf0 sf1))) + (let* ((s0 (tp-shape tp-t)) + (s1 (tp-shape tp-u)) + (sf0 (min-shape m s0)) + (sf1 (min-shape n s1)) + (sf-out (shape-fn sf0 sf1))) + (cond + ((and (number? tp-t) (number? tp-u)) + (f tp-t tp-u)) + [(and (tpromise? tp-t) (tpromise? tp-u) + (null? (tpromise-shape tp-t)) + (null? (tpromise-shape tp-u))) + (tpmake-ext2-ρ-scalar f signature tp-t tp-u sf-out)] + [expects-prealloc? (tpmake-ext2-ρ tp-t tp-u f signature m n shape-fn (ext2-shapes s0 s1 m n sf-out - (λ (s-out . _) s-out))))] - [else - (let* ((s0 (tp-shape tp-t)) - (s1 (tp-shape tp-u)) - (sf0 (min-shape m s0)) - (sf1 (min-shape n s1)) - (sf-out (shape-fn sf0 sf1)) - (t-shape (min-shape m s0)) - (u-shape (min-shape n s1)) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-ρ - f - t-shape - u-shape - out-shape))) - (tpmake-ext2-ρ - tp-t tp-u - flat-f signature m n shape-fn - (ext2-shapes s0 s1 m n sf-out - (λ (s-out . _) s-out))))])))) + (λ (s-out . _) s-out)))] + [else + (let ((flat-f (functional->preallocated-2-ρ + f sf0 sf1 sf-out))) + (tpmake-ext2-ρ + tp-t tp-u + flat-f signature m n shape-fn + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out))))]))))) (define scalar-shape (λ (s0 [s1 '()]) '())) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index c2500ec..e137617 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -109,7 +109,9 @@ counter))) ;; TODO: Try using the signature field of tpromise struct as keys instead tcomp -;; references +;; references NOTE: We will need to generate signature out of the signature keys +;; every time we need to call hash-ref or hash-set, so maybe we shouldn't +;; implement this TODO (define cr-tpromise (λ (t counter uid) (match t From 67e3ad69fbd8f27f4c4ab093ef1dcdd1f31916c6 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 30 Dec 2023 21:14:33 -0500 Subject: [PATCH 78/83] [add-lazy]Fix extracting common subexpressions --- lazy/tensors/c3-compiler.rkt | 155 ++++++++++++++----------- lazy/tensors/test/test-c3-compiler.rkt | 61 +++++++++- 2 files changed, 145 insertions(+), 71 deletions(-) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index e137617..4b6f6bd 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -25,80 +25,94 @@ (define generate-ds-refs (λ (t) - (let-values (((t^ ref) (gdr-tpromise t 0))) + (let-values (((t^ ref) (gdr-tpromise t 0 (make-hasheq)))) t^))) (define gdr-tpromise - (λ (tp ref) + (λ (tp ref memo) (match tp ((tpromise tc s dss sign) - (let-values (((tc^ ref^) (gdr-tcomp tc ref))) + (let-values (((tc^ ref^) (gdr-tcomp tc ref memo))) (values (tpromise tc^ s dss sign) ref^)))))) (define gdr-tcomp - (λ (tc ref) - (match tc - ((? number?) (values tc ref)) - [(tcomp-list->tensor lst) - (for/fold - ((tcs '()) - (ref^ ref) - #:result (values (tcomp-list->tensor (reverse tcs)) ref^)) - ((l lst)) - (let-values (((tc ref^^) - (cond - ((tpromise? l) (gdr-tpromise l ref^)) - ((number? l) (values l ref^)) - (else (error 'gdr-list->tensor "Unexpected: ~a" l))))) - (values (cons tc tcs) ref^^)))] - [(tcomp-tref tp (tcomp-ds-ref #f)) - (let-values (((tp^ ref^) (gdr-tpromise tp ref))) - (values (tcomp-tref tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] - [(tcomp-trefs tp (tcomp-ds-ref #f)) - (let-values (((tp^ ref^) (gdr-tpromise tp ref))) - (values (tcomp-trefs tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z - out-ref0 - out-ref1 i) - (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref)) - ((tp-t1^ ref^^) (gdr-tpromise tp-t1 ref^)) - ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^))) - (cond - ((and (eqv? i 0) (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref0)))) - (set-ext2-∇-result-res! out-ref0 (tcomp-ds-ref ref^^^))) - ((and (eqv? i 1) (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref1)))) - (set-ext2-∇-result-res! out-ref1 (tcomp-ds-ref ref^^^)))) - (values (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ - out-ref0 out-ref1 i) - (add1 ref^^^)))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) - (let*-values (((tp^ ref^) (gdr-tpromise tp ref)) - ((zp^ ref^^) (gdr-tpromise zp ref^))) - (values (tcomp-ext1-∇ tp^ zp^ f sign m shape-fn) ref^^))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) - (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref)) - ((tp-u^ ref^^) (gdr-tpromise tp-u ref^))) - (values (tcomp-ext2-ρ-scalar f sign tp-t^ tp-u^) ref^^))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) - (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref)) - ((tp-u^ ref^^) (gdr-tpromise tp-u ref^))) - (values (tcomp-ext2-ρ tp-t^ tp-u^ f sign m n shape-fn) ref^^))] - [(tcomp-ext1-ρ-scalar f sign tp) - (let-values (((tp^ ref^) (gdr-tpromise tp ref))) - (values (tcomp-ext1-ρ-scalar f sign tp^) ref^))] - [(tcomp-ext1-ρ f sign m shape-fn tp) - (let-values (((tp^ ref^) (gdr-tpromise tp ref))) - (values (tcomp-ext1-ρ f sign m shape-fn tp^) ref^))] - [(tcomp-reshape s tp) - (let-values (((tp^ ref^) (gdr-tpromise tp ref))) - (values (tcomp-reshape s tp^) ref^))] - [(tcomp-ds-ref #f) (values (tcomp-ds-ref ref) (add1 ref))] - ;;need these cases for testing compiler invariant - [(tcomp-let lhs rhs body) - (let*-values (((rhs^ ref^) (gdr-tpromise rhs ref)) - ((body^ ref^^) (gdr-tpromise body ref^))) - (values (tcomp-let lhs rhs^ body^) ref^^))] - [(tcomp-var name) (values (tcomp-var name) ref)]))) + (λ (tc ref memo) + (cond + ((hash-ref memo tc #f) + => + (λ (res/ref-count) + (match-let (((cons res ref-count) res/ref-count)) + (values res (+ ref ref-count))))) + (else + (let-values + (((res ref^) + (match tc + ((? number?) (values tc ref)) + [(tcomp-list->tensor lst) + (for/fold + ((tcs '()) + (ref^ ref) + #:result (values (tcomp-list->tensor (reverse tcs)) ref^)) + ((l lst)) + (let-values (((tc ref^^) + (cond + ((tpromise? l) (gdr-tpromise l ref^ memo)) + ((number? l) (values l ref^)) + (else (error 'gdr-list->tensor + "Unexpected: ~a" l))))) + (values (cons tc tcs) ref^^)))] + [(tcomp-tref tp (tcomp-ds-ref #f)) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-tref tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] + [(tcomp-trefs tp (tcomp-ds-ref #f)) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-trefs tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + out-ref0 + out-ref1 i) + (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref memo)) + ((tp-t1^ ref^^) (gdr-tpromise tp-t1 ref^ memo)) + ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^ memo))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref0)))) + (set-ext2-∇-result-res! out-ref0 (tcomp-ds-ref ref^^^))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref1)))) + (set-ext2-∇-result-res! out-ref1 (tcomp-ds-ref ref^^^)))) + (values (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ + out-ref0 out-ref1 i) + (add1 ref^^^)))] + [(tcomp-ext1-∇ tp zp f sign m shape-fn) + (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) + ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) + (values (tcomp-ext1-∇ tp^ zp^ f sign m shape-fn) ref^^))] + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-ext2-ρ-scalar f sign tp-t^ tp-u^) ref^^))] + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-ext2-ρ tp-t^ tp-u^ f sign m n shape-fn) ref^^))] + [(tcomp-ext1-ρ-scalar f sign tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-ext1-ρ-scalar f sign tp^) ref^))] + [(tcomp-ext1-ρ f sign m shape-fn tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-ext1-ρ f sign m shape-fn tp^) ref^))] + [(tcomp-reshape s tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-reshape s tp^) ref^))] + [(tcomp-ds-ref #f) (values (tcomp-ds-ref ref) (add1 ref))] + ;;need these cases for testing compiler invariant + [(tcomp-let lhs rhs body) + (let*-values (((rhs^ ref^) (gdr-tpromise rhs ref memo)) + ((body^ ref^^) (gdr-tpromise body ref^ memo))) + (values (tcomp-let lhs rhs^ body^) ref^^))] + [(tcomp-var name) (values (tcomp-var name) ref)]))) + (hash-set! memo tc (cons res (- ref^ ref))) + (values res ref^)))))) ;; Count references so that the tcomp AST nodes that refer to the same memory ;; location i.e. common AST nodes get extracted by let-binding them in the @@ -109,9 +123,7 @@ counter))) ;; TODO: Try using the signature field of tpromise struct as keys instead tcomp -;; references NOTE: We will need to generate signature out of the signature keys -;; every time we need to call hash-ref or hash-set, so maybe we shouldn't -;; implement this TODO +;; references (define cr-tpromise (λ (t counter uid) (match t @@ -191,7 +203,10 @@ (run-compiler-ecs (ecs-tpromise t counter) '()))) (for/fold ((body instrs)) ((binding bindings)) - (tcomp-let (car binding) (cdr binding) body))))) + (tpromise (tcomp-let (car binding) + (tpromise (cdr binding) '() (box '()) (box '())) + body) + '() (box '()) (box '())))))) (define ecs-tpromise (λ (tc counter) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index 96ab6f1..c38707c 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -109,4 +109,63 @@ (string-append "eq? equivalent flat tensors and trefs index lists" " used to construct the source AST must" " be eq? equivalent in the data segment as well."))) - ) + + (define count-tcomp-var + (λ (tp) + (ctv-tcomp (tpromise-tensor tp)))) + + (define ctv-tcomp + (λ (tc) + (match tc + ((? number?) 0) + [(tcomp-list->tensor lst) + (for/sum + ((l lst)) + (cond + ((tpromise? l) (count-tcomp-var l)) + ((number? l) 0) + (else (error 'cdsr-list->tensor "Unexpected: ~a" l))))] + [(tcomp-tref tp _) (count-tcomp-var tp)] + [(tcomp-trefs tp _) (count-tcomp-var tp)] + [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + out-ref0 + out-ref1 i) + (let ((c0 (count-tcomp-var tp-t0)) + (c1 (count-tcomp-var tp-t1)) + (cz (count-tcomp-var tp-z))) + (+ c0 c1 cz))] + [(tcomp-ext1-∇ tp zp f sign m shape-fn) + (let ((ct (count-tcomp-var tp)) + (cz (count-tcomp-var zp))) + (+ ct cz))] + [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (let ((ct (count-tcomp-var tp-t)) + (cu (count-tcomp-var tp-u))) + (+ ct cu))] + [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (let ((ct (count-tcomp-var tp-t)) + (cu (count-tcomp-var tp-u))) + (+ ct cu))] + [(tcomp-ext1-ρ-scalar f sign tp) (count-tcomp-var tp)] + [(tcomp-ext1-ρ f sign m shape-fn tp) (count-tcomp-var tp)] + [(tcomp-reshape s tp) (count-tcomp-var tp)] + [(tcomp-ds-ref i) 0] + [(tcomp-let lhs rhs body) + (let ((cr (count-tcomp-var rhs)) + (cb (count-tcomp-var body))) + (+ cr cb))] + [(tcomp-var name) 1]))) + + (define get-common-subexprs + (λ (tp) + (let ((instrs (generate-ds-refs tp))) + (extract-common-subexpressions instrs (count-references instrs))))) + + (check-equal? + (count-tcomp-var (get-common-subexprs (get-test-program 'common-subexpression))) + 2) + (check-equal? + (count-tcomp-var + (get-common-subexprs (get-test-program 'nested-common-subexpression))) + 2) +) From 9bbe93b16f19b2db8da830fdc8c947487247f655 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Sat, 27 Jan 2024 07:38:46 -0500 Subject: [PATCH 79/83] [add-lazy]Fix intepreter and add compiler invariance tests --- lazy/tensors/c1-racket-runtime.rkt | 6 +- lazy/tensors/c2-interpreter.rkt | 112 +++++++++++++--------- lazy/tensors/c3-compiler.rkt | 62 ++++++------ lazy/tensors/test/test-c2-interpreter.rkt | 32 +++++++ lazy/tensors/test/test-c3-compiler.rkt | 80 ++++++++++++---- 5 files changed, 188 insertions(+), 104 deletions(-) create mode 100644 lazy/tensors/test/test-c2-interpreter.rkt diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 5443ef3..0600e04 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -5,7 +5,7 @@ (struct ext2-∇-result (res) #:mutable #:transparent) -(define ext2-∇-forcer +(define ext2-∇-forcer! (λ (fᵈ r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) (let* ((f0 (ensure-flat t0)) (f1 (ensure-flat t1)) @@ -83,5 +83,5 @@ (provide runtime flat? flat:build-tensor flat:list->tensor flat:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! - ext2-∇-forcer scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ - flat flat-store flat-offset flat-ext1-ρ data-segment) + ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ + flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index 82c0604..423029c 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -3,91 +3,111 @@ (require "c0-ast.rkt") (require (only-in "c1-racket-runtime.rkt" runtime flat? flat:build-tensor flat:list->tensor - flat:tref rt:trefs ext2-∇-result-res ext2-∇-forcer - scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store - flat-offset flat-ext1-ρ data-segment)) + set-ext2-∇-result-res! flat:tref rt:trefs ext2-∇-result-res + ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ + flat flat-store flat-offset flat-ext1-ρ data-segment + data-segment-ref)) -(define interp-tensor-tcomp - (λ (tc env ds) +(define interp-tcomp + (λ (tc env) (match tc [(tcomp-list->tensor lst) (let ((eval-list (for/list ((arg lst)) - (interp-tensor-expr arg env ds)))) + (cond + ((tpromise? arg) (interp-tpromise arg env)) + ((number? arg) arg) + (else (error 'interp-list->tensor "Unexpected: ~a" arg)))))) (flat:list->tensor eval-list))] - [(tcomp-tref tp i) - (flat:tref (interp-tensor-expr tp env ds) - (interp-tensor-expr i env ds))] - [(tcomp-trefs tp b) - (rt:trefs (interp-tensor-expr tp env ds) - (interp-tensor-expr b env ds))] + [(tcomp-tref tp (and i (tcomp-ds-ref _))) + (flat:tref (interp-tpromise tp env) + (interp-tcomp i env))] + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) + (rt:trefs (interp-tpromise tp env) + (interp-tcomp b env))] [(tcomp-ext2-∇ fᵈ _ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) - ;; TODO: fix this case because we now use the data segment rather than - ;; ext2-∇-result for output - (let* ([b (if (zero? i) out0 out1)] - [v (ext2-∇-result-res b)]) + (let ((t0-instrs (interp-tpromise tp-t0 env)) + (t1-instrs (interp-tpromise tp-t1 env)) + (z-instrs (interp-tpromise tp-z env))) (cond - ((eqv? v 'uncalculated) - (ext2-∇-forcer fᵈ r0 r1 shape-fn - (interp-tensor-expr tp-t0 env ds) - (interp-tensor-expr tp-t1 env ds) - (interp-tensor-expr tp-z env ds) - out0 out1) - (ext2-∇-result-res b)) - (else v)))] + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out0)))) + (set-ext2-∇-result-res! out0 (tcomp-ds-ref (current-ds-ref-index))) + (current-ds-ref-index (add1 (current-ds-ref-index)))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (set-ext2-∇-result-res! out1 (tcomp-ds-ref (current-ds-ref-index))) + (current-ds-ref-index (add1 (current-ds-ref-index))))) + (let* ((out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1))) + (index (if (zero? i) out-idx0 out-idx1)) + (v (data-segment-ref index))) + (cond + ((eqv? v 'uncalculated) + (ext2-∇-forcer! fᵈ r0 r1 shape-fn + t0-instrs t1-instrs + z-instrs out-idx0 out-idx1) + (data-segment-ref index)) + (else v))))] [(tcomp-ext1-∇ tp zp f _ m shape-fn) (scalarize (flat-ext1-∇ f m shape-fn - (ensure-flat (interp-tensor-expr tp env ds)) - (ensure-flat (interp-tensor-expr zp env ds))))] + (ensure-flat (interp-tpromise tp env)) + (ensure-flat (interp-tpromise zp env))))] [(tcomp-ext2-ρ-scalar f _ tp-t tp-u) - (f (interp-tensor-expr tp-t env ds) (interp-tensor-expr tp-u env ds))] + (f (interp-tpromise tp-t env) (interp-tpromise tp-u env))] [(tcomp-ext2-ρ tp-t tp-u f _ m n shape-fn) (scalarize (flat-ext2-ρ f m n shape-fn - (ensure-flat (interp-tensor-expr tp-t env ds)) - (ensure-flat (interp-tensor-expr tp-u env ds))))] + (ensure-flat (interp-tpromise tp-t env)) + (ensure-flat (interp-tpromise tp-u env))))] [(tcomp-ext1-ρ-scalar f _ tp) - (f (interp-tensor-expr tp env ds))] + (f (interp-tpromise tp env))] [(tcomp-ext1-ρ f _ m shape-fn tp) (scalarize (flat-ext1-ρ f m shape-fn - (ensure-flat (interp-tensor-expr tp env ds))))] + (ensure-flat (interp-tpromise tp env))))] [(tcomp-reshape s tp) - (flat s - (flat-store (interp-tensor-expr tp env ds)) - (flat-offset (interp-tensor-expr tp env ds)))] + (let ([interp-tp (interp-tpromise tp env)]) + (flat s (flat-store interp-tp) (flat-offset interp-tp)))] [(tcomp-let lhs rhs body) - (interp-tensor-expr + (interp-tpromise body (cons (cons lhs - (interp-tensor-expr rhs env ds)) - env) - ds)] + (interp-tpromise rhs env)) + env))] [(tcomp-var name) (cond ((assv name env) => (λ (p) (cdr p))) (else (error 'interpret-free "Free variable: ~a" name)))] + [(tcomp-ds-ref #f) + (let ([out (data-segment-ref (current-ds-ref-index))]) + (current-ds-ref-index (add1 (current-ds-ref-index))) + out)] [(tcomp-ds-ref index) - (vector-ref ds index)]))) + ;; This case is run only for languages where the tcomp-ds-ref indices are + ;; generated by the generate-ds-refs pass. + (data-segment-ref index)]))) -(define interp-tensor-expr - (λ (t env ds) +(define interp-tpromise + (λ (t env) (match t - [(tpromise tc _ _ _) (interp-tensor-expr tc env ds)] - [v #:when (or (flat? v) (pair? v) (number? v)) v] - [(tcomp) (interp-tensor-tcomp t env ds)]))) + [(tpromise tc _ _ _) (interp-tcomp tc env)]))) +(define current-ds-ref-index (make-parameter #f)) (define interp-tensor - (λ (t ds) - (interp-tensor-expr t '() ds))) + (λ (tp) + (parameterize ([current-ds-ref-index 0] + [data-segment (dst->data-segment (tpromise-dst tp))]) + (interp-tpromise tp '())))) (define interp-racket (lambda (instrs ds) (parameterize ((data-segment ds)) (eval instrs runtime)))) +(include "test/test-c2-interpreter.rkt") (provide interp-racket interp-tensor) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 4b6f6bd..6b8437f 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -123,7 +123,9 @@ counter))) ;; TODO: Try using the signature field of tpromise struct as keys instead tcomp -;; references +;; references. The naive way to do this might be inefficient because of the +;; constant conversion between the tree representation of the signature and the +;; numeric hash signature. (define cr-tpromise (λ (t counter uid) (match t @@ -167,9 +169,9 @@ ((tpromise? l) (cr-tpromise l counter^^ uid^^)) ((number? l) (values counter^^ uid^^)) (else (error 'cr-list->tensor "Unexpected: ~a" l))))] - [(tcomp-tref tp i) + [(tcomp-tref tp (and i (tcomp-ds-ref _))) (cr-tpromise tp counter^ uid^)] - [(tcomp-trefs tp b) + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (cr-tpromise tp counter^ uid^)] [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) @@ -201,7 +203,13 @@ (λ (t counter) (let-values (((instrs bindings) (run-compiler-ecs (ecs-tpromise t counter) '()))) - (for/fold ((body instrs)) + (for/fold ((body instrs) + #:result ;; set-box! the data-segment of result so that + ;; applying it to interp-tensor works + (begin + (set-box! (tpromise-dst body) (unbox (tpromise-dst t))) + (set-box! (tpromise-sign body) (unbox (tpromise-sign t))) + body)) ((binding bindings)) (tpromise (tcomp-let (car binding) (tpromise (cdr binding) '() (box '()) (box '())) @@ -245,12 +253,12 @@ instrs-list-compiler (λ (instrs-list) (inj-ecs-tcomp (tcomp-list->tensor instrs-list) tc-counter-data))))] - [(tcomp-tref tp i) + [(tcomp-tref tp (and i (tcomp-ds-ref _))) (->ecs (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-tref instrs i) tc-counter-data)))] - [(tcomp-trefs tp b) + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (->ecs (ecs-tpromise tp counter) (λ (instrs) @@ -374,18 +382,19 @@ (match tc [v #:when (number? v) v] [(tcomp-list->tensor lst) - (let ((instrs-list (map (λ (t) - (cond - ((tpromise? t) (gr-tpromise t)) - ((number? t) t) - (else (error 'gr-list->tensor "Unexpected: ~a" t)))) - lst))) + (let ((instrs-list + (map (λ (t) + (cond + ((tpromise? t) (gr-tpromise t)) + ((number? t) t) + (else (error 'gr-list->tensor "Unexpected: ~a" t)))) + lst))) `(flat:list->tensor (list ,@instrs-list)))] - [(tcomp-tref tp i) + [(tcomp-tref tp (and i (tcomp-ds-ref _))) (let ((instrs (gr-tpromise tp)) (i-instrs (gr-tcomp i))) `(flat:tref ,instrs ,i-instrs))] - [(tcomp-trefs tp b) + [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (let ((instrs (gr-tpromise tp)) (b-instrs (gr-tcomp b))) `(rt:trefs ,instrs ,b-instrs))] @@ -400,9 +409,9 @@ [v (data-segment-ref index)]) (cond ((eqv? v 'uncalculated) - (ext2-∇-forcer ,fᵈ ,r0 ,r1 ,shape-fn - ,t0-instrs ,t1-instrs - ,z-instrs ,out-idx0 ,out-idx1) + (ext2-∇-forcer! ,fᵈ ,r0 ,r1 ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out-idx0 ,out-idx1) (data-segment-ref index)) (else v)))))] [(tcomp-ext1-∇ tp zp f sign m shape-fn) @@ -489,23 +498,6 @@ (parameterize ([current-namespace runtime]) (compile-syntax (expand r))))) -;;TODO: update this for new compiler passes -(define compile-tensor/checks - (λ (t) - t - #; - (let-values (((eds-instrs ds) (extract-data-segment t))) - (flat:check-tensor-equal? (interp-tensor t) (interp-tensor eds-instrs)) - (let ((counter (count-references t))) - (let ((extracted (extract-common-subexpressions t counter))) - (flat:check-tensor-equal? (interp-tensor t) (interp-tensor extracted)) - (for/list ((cd (hash-values (count-references extracted)))) - (check-equal? (counter-data-ref-count cd) 1)) - (let-values (((rkt env) (generate-racket extracted))) - (flat:check-tensor-equal? (interp-tensor extracted) - (interp-racket rkt env)) - (values rkt env))))))) - (define get-compiled (λ (t) (let-values (((instrs env) @@ -514,5 +506,5 @@ ,instrs)))) (include "test/test-c3-compiler.rkt") -(provide get-compiled compile-tensor compile-tensor/checks print-compiler? +(provide get-compiled compile-tensor print-compiler? (rename-out (cache compiler-cache))) diff --git a/lazy/tensors/test/test-c2-interpreter.rkt b/lazy/tensors/test/test-c2-interpreter.rkt new file mode 100644 index 0000000..c20dfdb --- /dev/null +++ b/lazy/tensors/test/test-c2-interpreter.rkt @@ -0,0 +1,32 @@ +(module+ test + (require rackunit) + (require "B-test-programs.rkt") + (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + + (for (((test-name test-data) (in-hash test-programs))) + (match-define (test-program-data th res) test-data) + (match res + ((eval-res-1 res) + (let* ((tp (th)) + (interped (interp-tensor tp))) + (flat:check-tensor-equal? + interped res + (format "Expected result doesn't match in test case ~a" + test-name)) + (check-equal? (tpromise-shape tp) (flat:shape interped)))) + ((eval-res-2 res1 res2) + (let*-values (((tp1 tp2) (th)) + ((interped1) (interp-tensor tp1)) + ((interped2) (interp-tensor tp2))) + (flat:check-tensor-equal? + interped1 res1 + (format "Expected first result doesn't match in test case ~a" + test-name)) + (check-equal? (tpromise-shape tp1) (flat:shape interped1)) + (flat:check-tensor-equal? + interped2 res2 + (format "Expected second result doesn't match in test case ~a" + test-name)) + (check-equal? (tpromise-shape tp2) (flat:shape interped2)))))) + +) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index c38707c..0585290 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -2,29 +2,68 @@ (require rackunit) (require "B-test-programs.rkt") (require "0-lazy.rkt") + (require "c2-interpreter.rkt") + (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (define current-test-program-name (make-parameter #f)) (define-check (check-compiler-invariants tp) - (let-values (((instrs ds) (compile-tensor tp))) - (with-check-info - (('data-segment ds) - ('instrs instrs)) - (define test-name-string - (cond - ((current-test-program-name) (format "In test case: ~a" - (current-test-program-name))) - (else ""))) - 'ok - ;;TODO: Add a check to ensure the number of tcomp-ds-ref occurring in - ;;the input equals the size of the data segment - #; - (for ((name/flat ds)) - (unless (and (flat:flat? (cdr name/flat)) - (not (null? (flat:flat-shape (cdr name/flat))))) - (fail-check (format (string-append "Value associated with the variable" - " ~a should be a flat tensor. " - "Associated value found: ~a") - (car name/flat) (cdr name/flat)))))))) + (define ds (dst->data-segment (tpromise-dst tp))) + (define signature (sign (tpromise-sign tp))) + (define interp-tp (interp-tensor tp)) + (with-check-info + (('data-segment ds) + ('signature signature) + ('input-computation (tpromise-tensor tp)) + ('expected-interpretation interp-tp) + ('test-name (current-test-program-name))) + (for ((d ds)) + (unless (or (number? d) + (flat:flat? d) + (eqv? d 'uncalculated)) + (fail-check (format (string-append "Data segment should only contain flat tensors " + ", the symbol 'uncalculated or numbers." + " Found: ~a") + d)))) + (parameterize ((cache (make-hash))) + (let* ((instrs-dsr (generate-ds-refs tp)) + (interp-dsr (interp-tensor instrs-dsr))) + (unless (flat:tensor-equal? interp-dsr interp-tp) + (fail-check (format + (string-append + "Result of interpreting pass generate-ds-ref doesn't" + " match expected interpretation. Actual " + "interpretation: ~a~n")) + interp-dsr)) + (let ((counter (count-references instrs-dsr))) + (let* ((extracted (extract-common-subexpressions instrs-dsr counter)) + (interp-extracted (interp-tensor extracted))) + (unless (flat:tensor-equal? interp-extracted interp-tp) + (fail-check (format + (string-append + "Result of interpreting pass" + " extract-common-subexpression doesn't" + " match expected interpretation. Actual " + "interpretation: ~a~n")) + interp-extracted)) + (let* ((gr (generate-racket extracted)) + (rkt (compile-racket gr)) + (interp-rkt (interp-racket rkt ds))) + (unless (flat:tensor-equal? interp-rkt interp-tp) + (fail-check (format + (string-append + "Result of interpreting compiled racket code doesn't" + " match expected interpretation. Actual " + "interpretation: ~a~n")) + interp-rkt)) + (hash-set! (cache) signature rkt) + (compile-tensor tp) + (unless (eqv? (hash-count (cache)) 1) + (fail-check (format + (string-append + "Compiling the same tpromise again shouldn't" + " change the number of entries in the cache." + " Number of cache entries: ~a~n") + (hash-count (cache)))))))))))) (for (((test-name test-data) (in-hash test-programs))) (match-define (test-program-data th res) test-data) @@ -37,6 +76,7 @@ (let*-values (((tp1 tp2) (th))) (check-compiler-invariants tp1) (check-compiler-invariants tp2)))))) + (define-check (check-signatures-equal? t1 t2) (let ((sig1 (tpromise-sign t1)) (sig2 (tpromise-sign t2))) From 64c313f8f90bd4fa0bf77cf4ca7727bb70296734 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:31:22 -0400 Subject: [PATCH 80/83] [add-lazy]Add zeroes as a primitive --- lazy.rkt | 4 +++- lazy/ext-ops.rkt | 2 +- lazy/ext-ops/A-scalar-ops.rkt | 5 ++++- lazy/no-duals-no-overrides.rkt | 2 +- lazy/no-duals.rkt | 2 +- lazy/no-overrides.rkt | 2 +- 6 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lazy.rkt b/lazy.rkt index 7b5bba9..0505225 100644 --- a/lazy.rkt +++ b/lazy.rkt @@ -8,6 +8,8 @@ (require "lazy/ext-ops.rkt") (provide + tolerance + len ref refr tref tlen list->tensor tensor build-tensor @@ -32,7 +34,7 @@ (d-concat concat) (d-concat-n concat-n)) +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ diff --git a/lazy/ext-ops.rkt b/lazy/ext-ops.rkt index 67709fa..8a58d1a 100644 --- a/lazy/ext-ops.rkt +++ b/lazy/ext-ops.rkt @@ -19,7 +19,7 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) (provide =-0-0 <-0-0 <=-0-0 >-0-0 >=-0-0 =-1 <-1 >-1 <=-1 >=-1 !=-1) diff --git a/lazy/ext-ops/A-scalar-ops.rkt b/lazy/ext-ops/A-scalar-ops.rkt index 06bd0f9..2049b41 100644 --- a/lazy/ext-ops/A-scalar-ops.rkt +++ b/lazy/ext-ops/A-scalar-ops.rkt @@ -121,6 +121,9 @@ (λ (x) (*-ρ x x))) +(define zeroes-ρ + (ext1-ρ (λ (_) 0.0) 0)) + (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 @@ -132,4 +135,4 @@ +-ρ --ρ *-ρ /-ρ expt-ρ exp-ρ log-ρ abs-ρ - rectify-ρ sqrt-ρ sqr-ρ) + rectify-ρ sqrt-ρ sqr-ρ zeroes-ρ) diff --git a/lazy/no-duals-no-overrides.rkt b/lazy/no-duals-no-overrides.rkt index ac07a7a..07ca22e 100644 --- a/lazy/no-duals-no-overrides.rkt +++ b/lazy/no-duals-no-overrides.rkt @@ -20,7 +20,7 @@ ;; From ext-ops +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ concat-ρ concat-n-ρ flatten-ρ diff --git a/lazy/no-duals.rkt b/lazy/no-duals.rkt index 927c8c7..cd1bcaf 100644 --- a/lazy/no-duals.rkt +++ b/lazy/no-duals.rkt @@ -20,7 +20,7 @@ ;; From ext-ops (rename-out (+-ρ +) (--ρ -) (*-ρ *) (/-ρ /) (rectify-ρ rectify) - (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) + (exp-ρ exp) (log-ρ log) (expt-ρ expt) (sqrt-ρ sqrt) (sqr-ρ sqr) (zeroes-ρ zeroes) (sum-ρ sum) (abs-ρ abs) (*-2-1-ρ *-2-1) (argmax-ρ argmax) (max-ρ max) (sum-cols-ρ sum-cols) (correlate-ρ correlate) (flatten-ρ flatten) (concat-ρ concat) (concat-n-ρ concat-n)) diff --git a/lazy/no-overrides.rkt b/lazy/no-overrides.rkt index 35dcbdd..05844b7 100644 --- a/lazy/no-overrides.rkt +++ b/lazy/no-overrides.rkt @@ -34,7 +34,7 @@ d-flatten d-concat d-concat-n +-ρ --ρ *-ρ /-ρ rectify-ρ - exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ + exp-ρ log-ρ expt-ρ sqrt-ρ sqr-ρ zeroes-ρ sum-ρ abs-ρ *-2-1-ρ argmax-ρ max-ρ sum-cols-ρ correlate-ρ flatten-ρ concat-ρ concat-n-ρ From 554e84628fea6460434cb669014057102ec0c8f5 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Thu, 20 Jun 2024 03:07:40 -0400 Subject: [PATCH 81/83] [add-lazy]Unify prim sign generation in acc & lazy impls --- lazy/autodiff/B-prims.rkt | 24 +++--- lazy/tensors/0-lazy.rkt | 162 ++++++++++++++++++++------------------ 2 files changed, 101 insertions(+), 85 deletions(-) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 44b7c7d..94796d3 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -8,11 +8,13 @@ (apply (prim-proc this) args))) (define prim1 - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) - (let ((prim-sign (symbol->string (gensym 'prim1)))) - (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? - (λ (da) - (prim1-dual ρ-fn ∇-fn da)))))) + (let ((id 0)) + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (let ((prim-sign (string-append "p1" (~r id #:base 16)))) + (set! id (add1 id)) + (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (λ (da) + (prim1-dual ρ-fn ∇-fn da))))))) (define prim1-dual (λ (ρ-fn ∇-fn da) @@ -24,11 +26,13 @@ ((κ da) da ga σ)))))))) (define prim2 - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) - (let ((prim-sign (symbol->string (gensym 'prim2)))) - (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? - (λ (da db) - (prim2-dual ρ-fn ∇-fn da db)))))) + (let ((id 0)) + (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (let ((prim-sign (string-append "p2" (~r id #:base 16)))) + (set! id (add1 id)) + (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (λ (da db) + (prim2-dual ρ-fn ∇-fn da db))))))) (define prim2-dual (λ (ρ-fn ∇-fn da db) diff --git a/lazy/tensors/0-lazy.rkt b/lazy/tensors/0-lazy.rkt index 208d921..61f3210 100644 --- a/lazy/tensors/0-lazy.rkt +++ b/lazy/tensors/0-lazy.rkt @@ -172,102 +172,114 @@ instructions refering to the same gensym variable ;; argument. The signature argument is only supposed to be passed within the ;; definition of ext1 and ext2 functions in B-prims.rkt. (define tp-ext1-ρ - (λ (f m - [shape-fn scalar-shape] - [expects-prealloc? #f] - [signature (format "~a" f)]) - (λ (tp) - (let* ((in-shape (tp-shape tp)) - (base-shape (min-shape m in-shape)) - (shape-fn-out (shape-fn base-shape)) - (out-shape (merge-shapes in-shape m shape-fn-out))) - (cond - [(scalar? tp) (f tp)] - [(and (tpromise? tp) - (null? (tpromise-shape tp))) - (tpmake-ext1-ρ-scalar f signature tp out-shape)] - [expects-prealloc? - (tpmake-ext1-ρ f signature m shape-fn tp out-shape)] - [else - (let ((flat-f (functional->preallocated-1-ρ f base-shape shape-fn-out))) - (tpmake-ext1-ρ flat-f signature m shape-fn tp out-shape))]))))) + (let ((id -1)) + (λ (f m + [shape-fn scalar-shape] + [expects-prealloc? #f] + [prim-sign (begin + (set! id (add1 id)) + (string-append "re1" (~r id #:base 16)))]) + (λ (tp) + (let* ((in-shape (tp-shape tp)) + (base-shape (min-shape m in-shape)) + (shape-fn-out (shape-fn base-shape)) + (out-shape (merge-shapes in-shape m shape-fn-out))) + (cond + [(scalar? tp) (f tp)] + [(and (tpromise? tp) + (null? (tpromise-shape tp))) + (tpmake-ext1-ρ-scalar f prim-sign tp out-shape)] + [expects-prealloc? + (tpmake-ext1-ρ f prim-sign m shape-fn tp out-shape)] + [else + (let ((flat-f (functional->preallocated-1-ρ f base-shape shape-fn-out))) + (tpmake-ext1-ρ flat-f prim-sign m shape-fn tp out-shape))])))))) ;; See comment for tp-ext1-ρ (define tp-ext2-ρ - (λ (f m n + (let ((id -1)) + (λ (f m n [shape-fn scalar-shape] [expects-prealloc? #f] - [signature (format "~a" f)]) - (λ (tp-t tp-u) - (let* ((s0 (tp-shape tp-t)) - (s1 (tp-shape tp-u)) - (sf0 (min-shape m s0)) - (sf1 (min-shape n s1)) - (sf-out (shape-fn sf0 sf1))) - (cond - ((and (number? tp-t) (number? tp-u)) - (f tp-t tp-u)) - [(and (tpromise? tp-t) (tpromise? tp-u) - (null? (tpromise-shape tp-t)) - (null? (tpromise-shape tp-u))) - (tpmake-ext2-ρ-scalar f signature tp-t tp-u sf-out)] - [expects-prealloc? - (tpmake-ext2-ρ - tp-t tp-u - f signature m n shape-fn - (ext2-shapes s0 s1 m n sf-out - (λ (s-out . _) s-out)))] - [else - (let ((flat-f (functional->preallocated-2-ρ - f sf0 sf1 sf-out))) + [prim-sign (begin + (set! id (add1 id)) + (string-append "re2" (~r id #:base 16)))]) + (λ (tp-t tp-u) + (let* ((s0 (tp-shape tp-t)) + (s1 (tp-shape tp-u)) + (sf0 (min-shape m s0)) + (sf1 (min-shape n s1)) + (sf-out (shape-fn sf0 sf1))) + (cond + ((and (number? tp-t) (number? tp-u)) + (f tp-t tp-u)) + [(and (tpromise? tp-t) (tpromise? tp-u) + (null? (tpromise-shape tp-t)) + (null? (tpromise-shape tp-u))) + (tpmake-ext2-ρ-scalar f prim-sign tp-t tp-u sf-out)] + [expects-prealloc? (tpmake-ext2-ρ tp-t tp-u - flat-f signature m n shape-fn + f prim-sign m n shape-fn (ext2-shapes s0 s1 m n sf-out - (λ (s-out . _) s-out))))]))))) + (λ (s-out . _) s-out)))] + [else + (let ((flat-f (functional->preallocated-2-ρ + f sf0 sf1 sf-out))) + (tpmake-ext2-ρ + tp-t tp-u + flat-f prim-sign m n shape-fn + (ext2-shapes s0 s1 m n sf-out + (λ (s-out . _) s-out))))])))))) (define scalar-shape (λ (s0 [s1 '()]) '())) ;; See comment for tp-ext1-ρ (define tp-ext1-∇ - (λ (f m + (let ((id -1)) + (λ (f m [shape-fn scalar-shape] [expects-prealloc? #f] - [signature (format "~a" f)]) - (λ (tp zp) - ;; - (cond - ((number? tp) (f tp zp)) - (expects-prealloc? - (tpmake-ext1-∇ tp zp f signature m shape-fn (tp-shape tp))) - (else - (let* ((in-shape (tpromise-shape tp)) - (base-shape (min-shape m in-shape)) - (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (tpmake-ext1-∇ tp zp flat-f signature m shape-fn (tp-shape tp)))))))) + [prim-sign (begin + (set! id (add1 id)) + (string-append "ne1" (~r id #:base 16)))]) + (λ (tp zp) + ;; + (cond + ((number? tp) (f tp zp)) + (expects-prealloc? + (tpmake-ext1-∇ tp zp f prim-sign m shape-fn (tp-shape tp))) + (else + (let* ((in-shape (tpromise-shape tp)) + (base-shape (min-shape m in-shape)) + (out-shape (shape-fn base-shape)) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) + (tpmake-ext1-∇ tp zp flat-f prim-sign m shape-fn (tp-shape tp))))))))) ;; See comment for tp-ext1-ρ (define tp-ext2-∇ - (λ (f m n + (let ((id -1)) + (λ (f m n [shape-fn scalar-shape] [expects-prealloc? #f] - [signature (format "~a" f)]) - (let ((tp-f - (λ (f tp-t tp-u tp-z) - (tp-d-ext2^ f signature m n shape-fn - tp-t tp-u tp-z)))) - (λ (tp-t tp-u tp-z) - (cond - (expects-prealloc? - (tp-f f tp-t tp-u tp-z)) - [else (let* ((t-shape (min-shape m (tp-shape tp-t))) - (u-shape (min-shape n (tp-shape tp-u))) - (out-shape (shape-fn t-shape u-shape)) - (flat-f (functional->preallocated-2-∇ - f t-shape u-shape out-shape))) - (tp-f flat-f tp-t tp-u tp-z))]))))) + [prim-sign (begin + (set! id (add1 id)) + (string-append "ne2" (~r id #:base 16)))]) + (let ((tp-f + (λ (f tp-t tp-u tp-z) + (tp-d-ext2^ f prim-sign m n shape-fn + tp-t tp-u tp-z)))) + (λ (tp-t tp-u tp-z) + (cond + (expects-prealloc? + (tp-f f tp-t tp-u tp-z)) + [else (let* ((t-shape (min-shape m (tp-shape tp-t))) + (u-shape (min-shape n (tp-shape tp-u))) + (out-shape (shape-fn t-shape u-shape)) + (flat-f (functional->preallocated-2-∇ + f t-shape u-shape out-shape))) + (tp-f flat-f tp-t tp-u tp-z))])))))) (define tp-d-ext2^ (λ (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) From 5f08b10b5a18a1471744de6309744401db8968a4 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:40:54 -0400 Subject: [PATCH 82/83] [add-lazy]Switch compiled tensor runtime to acc tensor impl --- lazy/autodiff/A-autodiff.rkt | 5 +- lazy/autodiff/B-prims.rkt | 89 +++++++-- lazy/autodiff/E-print.rkt | 2 +- lazy/autodiff/test/test-E-print.rkt | 50 ++--- lazy/ext-ops/A-scalar-ops.rkt | 134 ++++++++++---- lazy/ext-ops/B-comparators.rkt | 36 ++-- lazy/ext-ops/C-star-2-1.rkt | 60 ++++-- lazy/ext-ops/D-sum.rkt | 76 ++++++-- lazy/ext-ops/E-argmax.rkt | 43 ++++- lazy/ext-ops/F-max.rkt | 57 +++++- lazy/ext-ops/G-correlate.rkt | 89 +++++++-- lazy/ext-ops/I-flatten.rkt | 25 ++- lazy/ext-ops/K-concat.rkt | 55 +++++- lazy/ext-ops/test/test-G-correlate.rkt | 4 +- lazy/tensors/0-lazy.rkt | 90 ++++----- lazy/tensors/1-reflect.rkt | 9 +- lazy/tensors/A-equality.rkt | 6 +- lazy/tensors/B-test-programs.rkt | 212 ++++++++++++++-------- lazy/tensors/c0-ast.rkt | 54 +++--- lazy/tensors/c1-racket-runtime.rkt | 59 +++--- lazy/tensors/c2-interpreter.rkt | 28 +-- lazy/tensors/c3-compiler.rkt | 86 +++++---- lazy/tensors/test/test-1-reflect.rkt | 29 +-- lazy/tensors/test/test-c2-interpreter.rkt | 14 +- lazy/tensors/test/test-c3-compiler.rkt | 34 ++-- 25 files changed, 908 insertions(+), 438 deletions(-) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index c23a9ba..cae86b0 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../tensors.rkt") ;;---------------------------- @@ -52,7 +53,7 @@ (hash-set σ d (+-ρ z g))))) (define +-ρ - (ext2-ρ + 0 0)) + (ext2-ρ + (λ (a b) "@{a} + @{b}") 0 0)) ;;---------------------------- ;; Reverse-mode AD @@ -111,7 +112,7 @@ ((dual? v) (trace-print (ρ v) port)) (else (fprintf port "~a~%" v))))) -(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) +(define (one-like s) ((ext1-ρ (λ (x) 1.0) (λ (x) "1.0") 0) s)) (include "test/test-A-autodiff.rkt") diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 94796d3..e14ffab 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -1,21 +1,32 @@ #lang racket +(require (only-in "../../accelerated-tensors/ext-impl.rkt" + new-vec + apply-flat-ρ-fn-1 + apply-flat-ρ-fn-2 + apply-flat-∇-fn-1 + apply-flat-∇-fn-2)) (require "../tensors.rkt") (require "A-autodiff.ss") -(struct prim (ρ-fn ∇-fn shape-fn signature expects-prealloc? proc) +(struct prim (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape-fn signature expects-prealloc? proc) #:property prop:procedure (λ (this . args) (apply (prim-proc this) args))) +;;TODO: Add new ast nodes for the 4 forces being done in the four preallocated->functional-* functions (define prim1 (let ((id 0)) - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) (let ((prim-sign (string-append "p1" (~r id #:base 16)))) (set! id (add1 id)) - (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da) - (prim1-dual ρ-fn ∇-fn da))))))) + (prim1-dual (if #;#f expects-prealloc? (preallocated->functional-1-ρ ρ-fn shape) ρ-fn) + (if #;#f expects-prealloc? (preallocated->functional-1-∇ ∇-fn shape) ∇-fn) + da))))))) +;; TODO: Convert the use of force* into the construction of an AST so that we +;; don't prematurely trigger computation. (define prim1-dual (λ (ρ-fn ∇-fn da) (let ((ra (ρ da))) @@ -27,12 +38,14 @@ (define prim2 (let ((id 0)) - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) (let ((prim-sign (string-append "p2" (~r id #:base 16)))) (set! id (add1 id)) - (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da db) - (prim2-dual ρ-fn ∇-fn da db))))))) + (prim2-dual (if expects-prealloc? (preallocated->functional-2-ρ ρ-fn shape) ρ-fn) + (if expects-prealloc? (preallocated->functional-2-∇ ∇-fn shape) ∇-fn) + da db))))))) (define prim2-dual (λ (ρ-fn ∇-fn da db) @@ -45,6 +58,48 @@ (let ((σ-hat ((κ da) da ga σ))) ((κ db) db gb σ-hat))))))))) +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define preallocated->functional-1-ρ + (λ (ρ-fn shape-fn) + (λ (ra) + (force*1 ra + (λ (ra) + (apply-flat-ρ-fn-1 ρ-fn ra shape-fn)))))) + +(define preallocated->functional-1-∇ + (λ (∇-fn shape-fn) + (λ (ra z) + (force*2 + (λ () + (values ra z)) + (λ (ra z) + (apply-flat-∇-fn-1 ∇-fn ra z shape-fn)))))) + +(define preallocated->functional-2-ρ + (λ (ρ-fn shape-fn) + (λ (ra rb) + (force*2 + (λ () + (values ra rb)) + (λ (ra rb) + (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn)))))) + +(define preallocated->functional-2-∇ + (λ (∇-fn shape-fn) + (λ (ra rb z) + (force*2 + (λ () + (values ra rb)) + (λ (ra rb) + (force*1 + z + (λ (z) + (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn)))))))) + ;;---------------------------- ;; Dualized tensor op creators ;;---------------------------- @@ -53,10 +108,12 @@ (unless (prim? f) (error 'ext1-prim "Function to be extended must be a primitive. Found: ~a" f)) (prim1 - (ext1-ρ (prim-ρ-fn f) n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) - (ext1-∇ (prim-∇-fn f) n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) + (ext1-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "r" (prim-signature f))) + (prim-ρ-acc-fn f) + (ext1-∇ (prim-∇-fn f) (prim-∇-acc-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "n" (prim-signature f))) + (prim-∇-acc-fn f) (prim-shape-fn f) #f))) @@ -65,10 +122,12 @@ (unless (prim? f) (error 'ext2-prim "Function to be extended must be a primitive. Found: ~a" f)) (prim2 - (ext2-ρ (prim-ρ-fn f) m n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) - (ext2-∇ (prim-∇-fn f) m n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) + (ext2-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "r" (prim-signature f))) + (prim-ρ-acc-fn f) + (ext2-∇ (prim-∇-fn f) (prim-∇-acc-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "n" (prim-signature f))) + (prim-∇-acc-fn f) (prim-shape-fn f) #f))) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt index 270083d..b39f25a 100644 --- a/lazy/autodiff/E-print.rkt +++ b/lazy/autodiff/E-print.rkt @@ -3,7 +3,7 @@ (require "A-autodiff.rkt") (require "../tensors/0-lazy.rkt") (require "../tensors/1-reflect.rkt") -(require (except-in "../../flat-tensors/ext-impl.rkt" scalarize)) +(require (except-in "../../accelerated-tensors/ext-impl.rkt" scalarize)) (define max-tensor-print-length (make-parameter 5)) diff --git a/lazy/autodiff/test/test-E-print.rkt b/lazy/autodiff/test/test-E-print.rkt index 91fde6c..092f78b 100644 --- a/lazy/autodiff/test/test-E-print.rkt +++ b/lazy/autodiff/test/test-E-print.rkt @@ -18,54 +18,54 @@ deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) - (check-equal? (make-printable long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) (check-equal? (make-printable deep-tensor 3) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...))) (check-equal? (make-printable deeper-tensor 3) (fake-tensor (list (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) '...))) (parameterize ((max-tensor-print-length 3)) - (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) (list - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) (fake-tensor (list (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) '...)))))) diff --git a/lazy/ext-ops/A-scalar-ops.rkt b/lazy/ext-ops/A-scalar-ops.rkt index 2049b41..72b33f9 100644 --- a/lazy/ext-ops/A-scalar-ops.rkt +++ b/lazy/ext-ops/A-scalar-ops.rkt @@ -1,49 +1,108 @@ #lang racket +(require string-interpolation) (require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) (require "../autodiff.rkt") +(define +-0-0-ρ-acc + (λ (a b) + "@{a}+@{b}")) + (define +-0-0 (prim2 + + +-0-0-ρ-acc + (λ (a b z) + (values z z)) (λ (a b z) (values z z)))) +(define --0-0-ρ-acc + (λ (a b) + "@{a}-@{b}")) + (define --0-0 (prim2 - + --0-0-ρ-acc + (λ (a b z) + (values z (- z))) (λ (a b z) - (values z (- z))))) + (values z "(- @{z})")))) + +(define *-0-0-ρ-acc + (λ (a b) + "@{a}*@{b}")) (define *-0-0 (prim2 * + *-0-0-ρ-acc + (λ (a b z) + (values (* b z) (* a z))) (λ (a b z) - (values (* b z) (* a z))))) + (values "@{b}*@{z}" "@{a}*@{z}")))) + +(define /-0-0-ρ-acc + (λ (a b) + "@{a}/@{b}")) (define /-0-0 (prim2 / - (λ (a b z) - (values (* z (/ 1 b)) - (* z (/ (- a) (* b b))))))) + /-0-0-ρ-acc + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))) + (λ (a b z) + (values "(@{z} * (1 / @{b}))" + "(@{z} * ((- @{a}) / (@{b} * @{b})))")))) + +(define expt-0-0-ρ-acc + (λ (a b) + "pow(@{a}, @{b})")) (define expt-0-0 (prim2 expt - (λ (a b z) - (values (* z (* b (expt a (- b 1)))) - (* z (* (expt a b) (log a))))))) + expt-0-0-ρ-acc + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))) + (λ (a b z) + (values "(@{z} * (@{b} * pow(@{a}, (@{b} - 1))))" + "(@{z} * (pow(@{a}, @{b}) * log(@{a})))")))) + +(define exp-0-ρ-acc + (λ (a) + "exp(@{a})")) (define exp-0 (prim1 exp - (λ (a z) - (* z (exp a))))) + exp-0-ρ-acc + (λ (a z) + (* z (exp a))) + (λ (a z) + "(@{z} * exp(@{a}))"))) + +(define log-0-ρ-acc + (λ (a) + "log(@{a})")) (define log-0 (prim1 log - (λ (a z) - (* z (/ 1 a))))) + log-0-ρ-acc + (λ (a z) + (* z (/ 1 a))) + (λ (a z) + "(@{z} * (1 / @{a}))"))) + +(define sqrt-0-ρ-acc + (λ (a) + "sqrt(@{a})")) (define sqrt-0 (prim1 sqrt - (λ (x z) - (/ z (* 2 (sqrt x)))))) + sqrt-0-ρ-acc + (λ (x z) + (/ z (* 2 (sqrt x)))) + (λ (x z) + "(@{z} / (2 * sqrt(@{x})))"))) (define abs-0-ρ (λ (x) @@ -51,14 +110,22 @@ ((< x 0) (* -1 x)) (else x)))) +(define abs-0-ρ-acc + (λ (x) + "fabs(@{x})")) + (define abs-0-∇ (λ (x z) (cond ((< x 0) (- z)) (else z)))) +(define abs-0-∇-acc + (λ (x z) + "sign(@{x}) * @{z}")) + (define abs-0 - (prim1 abs-0-ρ abs-0-∇)) + (prim1 abs-0-ρ abs-0-ρ-acc abs-0-∇ abs-0-∇-acc)) (define rectify-0-ρ (λ (s) @@ -66,17 +133,25 @@ ((< s 0.0) 0.0) (else s)))) +(define rectify-0-ρ-acc + (λ (s) + "fmax(0.0f, @{s})")) + (define rectify-0-∇ (λ (s z) (cond ((< s 0.0) 0.0) (else z)))) +(define rectify-0-∇-acc + (λ (s z) + "step(0, @{s}) * @{z}")) + (define rectify-shape (λ (s) s)) (define rectify-0 - (prim1 rectify-0-ρ rectify-0-∇ rectify-shape)) + (prim1 rectify-0-ρ rectify-0-ρ-acc rectify-0-∇ rectify-0-∇-acc rectify-shape)) ;;------------------------------------ ;; differentiable extended functions. @@ -102,32 +177,29 @@ ;; non-differentiable extended functions. ;;------------------------------------ -(define *-ρ (ext2-ρ * 0 0)) -(define +-ρ (ext2-ρ + 0 0)) -(define --ρ (ext2-ρ - 0 0)) -(define /-ρ (ext2-ρ / 0 0)) -(define expt-ρ (ext2-ρ expt 0 0)) - -(define exp-ρ (ext1-ρ exp 0)) -(define log-ρ (ext1-ρ log 0)) -(define abs-ρ (ext1-ρ abs-0-ρ 0)) -(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) +(define *-ρ (ext2-ρ * *-0-0-ρ-acc 0 0)) +(define +-ρ (ext2-ρ + +-0-0-ρ-acc 0 0)) +(define --ρ (ext2-ρ - --0-0-ρ-acc 0 0)) +(define /-ρ (ext2-ρ / /-0-0-ρ-acc 0 0)) +(define expt-ρ (ext2-ρ expt expt-0-0-ρ-acc 0 0)) -(define sqrt-ρ - (λ (a) - (expt-ρ a 1/2))) +(define exp-ρ (ext1-ρ exp exp-0-ρ-acc 0)) +(define log-ρ (ext1-ρ log log-0-ρ-acc 0)) +(define abs-ρ (ext1-ρ abs-0-ρ abs-0-ρ-acc 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ rectify-0-ρ-acc 0)) +(define sqrt-ρ (ext1-ρ sqrt sqrt-0-ρ-acc 0)) (define sqr-ρ (λ (x) (*-ρ x x))) (define zeroes-ρ - (ext1-ρ (λ (_) 0.0) 0)) + (ext1-ρ (λ (_) 0.0) (λ (_) "0.0") 0)) (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 - exp-0 log-0 sqrt-0 abs-0 rectify-0 + exp-0 log-0 abs-0 rectify-0 sqrt-0 d+ d- d* d/ d-expt d-exp d-log d-abs diff --git a/lazy/ext-ops/B-comparators.rkt b/lazy/ext-ops/B-comparators.rkt index 7fcb184..3db8e0f 100644 --- a/lazy/ext-ops/B-comparators.rkt +++ b/lazy/ext-ops/B-comparators.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../autodiff.rkt") ;;---------------------------- @@ -24,7 +25,7 @@ (comparator >)) (define >=-0-0 - (comparator >=)) + (comparator >)) ;;---------------------------- ;; Tensorized comparators @@ -37,6 +38,11 @@ ((f (ρ da) (ρ db)) 1.0) (else 0.0))))) +(define comparator-ρ-acc + (λ (f) + (λ (a b) + "@{a} @{f} @{b}"))) + (define comparator-∇ (λ (f) (λ (da db z) @@ -44,40 +50,48 @@ ((f (ρ da) (ρ db)) (values z z)) (else (values 0.0 0.0)))))) +(define comparator-∇-acc + (λ (f) + (λ (a b z) + (let ((bool "@{a} @{f} @{b}")) + (values "@{bool}*@{z}" "@{bool}*@{z}"))))) + (define comparator-shape (λ (f) (λ (sa sb) sa))) (define comparator-prim - (λ (f) - (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + (λ (f f-acc) + (prim2 (comparator-ρ f) (comparator-ρ-acc f-acc) + (comparator-∇ f) (comparator-∇-acc f-acc) + (comparator-shape f)))) (define extended-comparator - (λ (f) - (ext2 (comparator-prim f) 0 0))) + (λ (f f-acc) + (ext2 (comparator-prim f f-acc) 0 0))) (define =-1 - (extended-comparator =)) + (extended-comparator = "==")) (define <-1 - (extended-comparator <)) + (extended-comparator < "<")) (define >-1 - (extended-comparator >)) + (extended-comparator > ">")) (define <=-1 - (extended-comparator <=)) + (extended-comparator <= "<=")) (define >=-1 - (extended-comparator >=)) + (extended-comparator >= ">=")) (define != (λ (a b) (not (= a b)))) (define !=-1 - (extended-comparator !=)) + (extended-comparator != "!=")) (include "test/test-B-comparators.rkt") diff --git a/lazy/ext-ops/C-star-2-1.rkt b/lazy/ext-ops/C-star-2-1.rkt index 629b1ba..fab49f2 100644 --- a/lazy/ext-ops/C-star-2-1.rkt +++ b/lazy/ext-ops/C-star-2-1.rkt @@ -1,5 +1,7 @@ #lang racket +(require string-interpolation) +(require "../../accelerated-tensors/ext-impl.rkt") (require (only-in "../tensors.rkt" ext2-ρ)) (require "../autodiff.rkt") @@ -8,35 +10,69 @@ v1 i1 stride1 v-out i-out stride-out) (for ([i (in-range 0 stride-out)]) - (vector-set! v-out (+ i-out i) - (* (vector-ref v0 (+ i0 i)) - (vector-ref v1 (+ i1 (modulo i stride1)))))))) + (vset! v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-ρ-acc + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + #< v max) (values v (+ (- i i0) 0.0))) (else (values max max-i)))))))) +(define argmax-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< max) { + max = v; + max_i = i - @{i0} + 0.0; + } + } + @{v-out}[@{i-out}] = max_i; +EOF + )) + (define argmax-1-∇ (λ (g0 v0 i0 stride0 vz iz stride-z) - (let ((z (vector-ref vz iz))) + (let ((z (vref vz iz))) (for ([i (in-range i0 (+ i0 stride0))]) - (vector-set! g0 i 0.0))))) + (vset! g0 i 0.0))))) + +(define argmax-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< v max) v) (else max))))))) +(define max-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< v max) (values v (- i i0))) (else (values max max-i)))))))) +(define max-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< max) { + max = v; + max_i = i - @{i0}; + } + } + for(int i=@{i0}; i<@{i0}+@{stride0}; i++) { + if(i == @{i0}+max_i) { + @{g0}[i] += z; + } else { + @{g0}[i] += 0.0; + } + } +EOF + )) + (define max-shape (λ (st) (cdr st))) (define max-1 - (prim1 max-1-ρ max-1-∇ max-shape #t)) + (prim1 max-1-ρ max-1-ρ-acc max-1-∇ max-1-∇-acc max-shape #t)) (define d-max (ext1 max-1 1)) (define max-ρ - (ext1-ρ max-1-ρ 1 max-shape #t)) + (ext1-ρ max-1-ρ max-1-ρ-acc 1 max-shape #t)) (include "test/test-F-max.rkt") diff --git a/lazy/ext-ops/G-correlate.rkt b/lazy/ext-ops/G-correlate.rkt index cfe156b..14c0cb6 100644 --- a/lazy/ext-ops/G-correlate.rkt +++ b/lazy/ext-ops/G-correlate.rkt @@ -1,5 +1,7 @@ #lang racket +(require string-interpolation) +(require "../../accelerated-tensors/ext-impl.rkt") (require (only-in "../tensors.rkt" ext2-ρ len)) (require "../autodiff.rkt") @@ -17,17 +19,39 @@ (let* ((i1-min (- i1 (modulo i1 nd))) (i1-max (+ i1-min nd))) (for ((i (in-range 0 b))) - (vector-set! v-out (+ i-out i) + (vset! v-out (+ i-out i) (for/fold ([sum 0.0]) ([j (in-range 0 md)]) (let ((ai (+ i0 (* i md) j)) (bi (- (+ i1 j) qd))) (cond ((and (>= bi i1-min) (< bi i1-max)) - (let ((a (vector-ref v0 ai)) - (b (vector-ref v1 bi))) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) (+ sum (* a b)))) (else sum)))))))))) +(define correlate-3-1-ρ-acc + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + #<= i1_min && bi < i1_max) { + sum += @{v0}[ai] * @{v1}[bi]; + } + } + @{v-out}[@{i-out}+i] = sum; + } +EOF + ))) + (define correlate-3-1-∇ (λ (nd md qd) (λ (g0 g1 @@ -37,17 +61,55 @@ (let* ((i1-min (- i1 (modulo i1 nd))) (i1-max (+ i1-min nd))) (for ((i (in-range 0 b))) - (let ((z (vector-ref vz (+ iz i)))) + (let ((z (vref vz (+ iz i)))) (for ([j (in-range 0 md)]) (let ((ai (+ i0 (* i md) j)) (bi (- (+ i1 j) qd))) (when (and (>= bi i1-min) (< bi i1-max)) - (let ((a (vector-ref v0 ai)) - (b (vector-ref v1 bi))) - (vector-set! g0 ai - (+ (vector-ref g0 ai) (* z b))) - (vector-set! g1 bi - (+ (vector-ref g1 bi) (* z a))))))))))))) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-3-1-∇-acc + (λ (nd md qd) + (λ (g + v0 i0 bmd + v1 i1 d + vz iz b) + (values + #<= i1_min && bi < i1_max) { + @{g}[ai] += z * @{v1}[bi]; + } + } + } +EOF + + #<= i1_min && bi < i1_max) { + @{g}[bi] += z * @{v0}[ai]; + } + } + } +EOF + )))) (define correlate-shape (λ (bmd nd) @@ -57,9 +119,10 @@ (λ (nd md qd) (prim2 (correlate-3-1-ρ nd md qd) + (correlate-3-1-ρ-acc nd md qd) (correlate-3-1-∇ nd md qd) - correlate-shape - #t))) + (correlate-3-1-∇-acc nd md qd) + correlate-shape #t))) (define d-correlate (λ (bank signal) @@ -83,7 +146,7 @@ (q (/ (- m 1) 2)) ;; This is the padding. (qd (* q d)) (md (* m d))) - ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape #t) + ((ext2-ρ (correlate-3-1-ρ nd md qd) (correlate-3-1-ρ-acc nd md qd) 3 1 correlate-shape #t) bank signal)))) (define last diff --git a/lazy/ext-ops/I-flatten.rkt b/lazy/ext-ops/I-flatten.rkt index bf24773..0ef02f6 100644 --- a/lazy/ext-ops/I-flatten.rkt +++ b/lazy/ext-ops/I-flatten.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) (require (only-in "../autodiff.rkt" prim1 ext1)) @@ -7,10 +8,30 @@ (λ (t) (reshape (flatten-shape (shape t)) t))) +(define flatten-2-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #<preallocated-1-ρ f base-shape shape-fn-out))) - (tpmake-ext1-ρ flat-f prim-sign m shape-fn tp out-shape))])))))) + (let ((flat-f (functional->preallocated-1-ρ f base-shape shape-fn-out)) + (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape shape-fn-out))) + (tpmake-ext1-ρ flat-f flat-f-acc prim-sign m shape-fn tp out-shape))])))))) ;; See comment for tp-ext1-ρ (define tp-ext2-ρ (let ((id -1)) - (λ (f m n + (λ (f f-acc m n [shape-fn scalar-shape] [expects-prealloc? #f] [prim-sign (begin @@ -216,19 +217,21 @@ instructions refering to the same gensym variable [(and (tpromise? tp-t) (tpromise? tp-u) (null? (tpromise-shape tp-t)) (null? (tpromise-shape tp-u))) - (tpmake-ext2-ρ-scalar f prim-sign tp-t tp-u sf-out)] + (tpmake-ext2-ρ-scalar f f-acc prim-sign tp-t tp-u sf-out)] [expects-prealloc? (tpmake-ext2-ρ tp-t tp-u - f prim-sign m n shape-fn + f f-acc prim-sign m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out)))] [else (let ((flat-f (functional->preallocated-2-ρ - f sf0 sf1 sf-out))) + f sf0 sf1 sf-out)) + (flat-f-acc (functional->preallocated-2-ρ-acc + f-acc sf0 sf1 sf-out))) (tpmake-ext2-ρ tp-t tp-u - flat-f prim-sign m n shape-fn + flat-f flat-f-acc prim-sign m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))))) @@ -238,7 +241,7 @@ instructions refering to the same gensym variable ;; See comment for tp-ext1-ρ (define tp-ext1-∇ (let ((id -1)) - (λ (f m + (λ (f f-acc m [shape-fn scalar-shape] [expects-prealloc? #f] [prim-sign (begin @@ -249,58 +252,63 @@ instructions refering to the same gensym variable (cond ((number? tp) (f tp zp)) (expects-prealloc? - (tpmake-ext1-∇ tp zp f prim-sign m shape-fn (tp-shape tp))) + (tpmake-ext1-∇ tp zp f f-acc prim-sign m shape-fn (tp-shape tp))) (else (let* ((in-shape (tpromise-shape tp)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (tpmake-ext1-∇ tp zp flat-f prim-sign m shape-fn (tp-shape tp))))))))) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-∇-acc f-acc base-shape out-shape))) + (tpmake-ext1-∇ tp zp flat-f flat-f-acc prim-sign m shape-fn (tp-shape tp))))))))) ;; See comment for tp-ext1-ρ (define tp-ext2-∇ (let ((id -1)) - (λ (f m n + (λ (f f-acc m n [shape-fn scalar-shape] [expects-prealloc? #f] [prim-sign (begin (set! id (add1 id)) (string-append "ne2" (~r id #:base 16)))]) (let ((tp-f - (λ (f tp-t tp-u tp-z) - (tp-d-ext2^ f prim-sign m n shape-fn + (λ (f f-acc tp-t tp-u tp-z) + (tp-d-ext2^ f f-acc prim-sign m n shape-fn tp-t tp-u tp-z)))) (λ (tp-t tp-u tp-z) (cond (expects-prealloc? - (tp-f f tp-t tp-u tp-z)) + (tp-f f f-acc tp-t tp-u tp-z)) [else (let* ((t-shape (min-shape m (tp-shape tp-t))) (u-shape (min-shape n (tp-shape tp-u))) (out-shape (shape-fn t-shape u-shape)) (flat-f (functional->preallocated-2-∇ - f t-shape u-shape out-shape))) - (tp-f flat-f tp-t tp-u tp-z))])))))) + f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc + f-acc t-shape u-shape out-shape))) + (tp-f flat-f flat-f-acc tp-t tp-u tp-z))])))))) (define tp-d-ext2^ - (λ (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) + (λ (fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) (let* ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) (values - (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn + (tpmake-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 0 (tp-shape tp-t0)) - (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn + (tpmake-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 1 (tp-shape tp-t1)))))) (define tp-rank (λ (tp) - (flat:len (tp-shape tp)))) + (acc:len (tp-shape tp)))) (define tp-reshape (λ (s tp) (cond - ((= (flat:size-of s) (flat:size-of (tpromise-shape tp))) + ((and (tpromise? tp) (= (acc:size-of s) (acc:size-of (tpromise-shape tp)))) (tpmake-reshape tp s)) - (else (error 'shape-error "Cannot reshape ~a to ~a~%" (tpromise-shape tp) s))))) + [(and (acc:flat? tp) (= (acc:size-of s) (acc:size-of (acc:shape tp)))) + (acc:reshape s tp)] + (else (error 'shape-error "Cannot reshape ~a to ~a~%" tp s))))) (define tensor? (lambda (tp) @@ -311,9 +319,9 @@ instructions refering to the same gensym variable (provide start-vector-manager vector-manager-report) (provide (rename-out - (flat:len len) - (flat:ref ref) - (flat:refr refr))) + (acc:len len) + (acc:ref ref) + (acc:refr refr))) (provide tensor tpromise? (rename-out @@ -335,4 +343,4 @@ instructions refering to the same gensym variable (tp-rank rank) (tp-shape shape) (tp-reshape reshape) - (flat:size-of size-of))) + (acc:size-of size-of))) diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt index df6575f..bc2a0c9 100644 --- a/lazy/tensors/1-reflect.rkt +++ b/lazy/tensors/1-reflect.rkt @@ -1,6 +1,6 @@ #lang racket -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (require "c0-ast.rkt") (require (only-in "c3-compiler.rkt" compiler-cache @@ -42,11 +42,10 @@ (cond [(and (tpromise? tp) (null? (tpromise-shape tp))) (tp-scalarize (↓ tp))] - [(and (flat:flat? tp) (null? (flat:flat-shape tp))) - (vector-ref (flat:flat-store tp) 0)] + [(and (acc:flat? tp) (null? (acc:flat-shape tp))) + (vector-ref (acc:flat-store tp) 0)] [else tp]))) -;; TODO: these force functions will be moved to the openCL runtime (define force*1 (λ (t f) (f (↓ t)))) diff --git a/lazy/tensors/A-equality.rkt b/lazy/tensors/A-equality.rkt index 7c420ca..8cfa8ba 100644 --- a/lazy/tensors/A-equality.rkt +++ b/lazy/tensors/A-equality.rkt @@ -1,11 +1,11 @@ #lang racket (require "1-reflect.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define tp-tensor-equal? (λ (tp-actual tp-expected) - (flat:tensor-equal? (↓ tp-actual) (↓ tp-expected)))) + (acc:tensor-equal? (↓ tp-actual) (↓ tp-expected)))) (require rackunit) (define-binary-check (tp-check-tensor-equal? tp-tensor-equal? actual expected)) @@ -13,6 +13,6 @@ (include "test/test-A-equality.rkt") (provide (rename-out - (flat:tolerance tolerance) + (acc:tolerance tolerance) (tp-tensor-equal? tensor-equal?) (tp-check-tensor-equal? check-tensor-equal?))) diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt index 450de2f..f301c73 100644 --- a/lazy/tensors/B-test-programs.rkt +++ b/lazy/tensors/B-test-programs.rkt @@ -1,7 +1,8 @@ #lang racket +(require string-interpolation) (require "0-lazy.rkt") -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define make-tref-test-program (λ (t) @@ -22,29 +23,29 @@ 'tensor-r1-0 (test-program-data (λ () (tensor 1 2 3)) - (eval-res-1 (flat:tensor 1 2 3))) + (eval-res-1 (acc:tensor 1 2 3))) 'tensor-r1-1 (test-program-data (λ () (tensor 1 2 3 4 5)) - (eval-res-1 (flat:tensor 1 2 3 4 5))) + (eval-res-1 (acc:tensor 1 2 3 4 5))) 'tensor-r1-2 (test-program-data (λ () (tensor 3.0 4.0 5.0)) - (eval-res-1 (flat:tensor 3.0 4.0 5.0))) + (eval-res-1 (acc:tensor 3.0 4.0 5.0))) 'tensor-r2-0 (test-program-data (λ () (tensor (tensor 1 2 3) (tensor 4 5 6))) - (eval-res-1 (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6)))) + (eval-res-1 (acc:tensor (acc:tensor 1 2 3) (acc:tensor 4 5 6)))) 'tensor-r2-1 (test-program-data (λ () (reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) (eval-res-1 - (flat:reshape '(2 3) (flat:tensor 3.0 4.0 5.0 7.0 8.0 9.0)))) + (acc:reshape '(2 3) (acc:tensor 3.0 4.0 5.0 7.0 8.0 9.0)))) 'build-tensor-r1-0 (test-program-data (λ () (build-tensor '(6) (λ (i) (* 3.0 (car i))))) - (eval-res-1 (flat:build-tensor '(6) + (eval-res-1 (acc:build-tensor '(6) (λ (i) (* 3.0 (car i)))))) 'build-tensor-r2-0 (test-program-data (λ () @@ -52,7 +53,7 @@ (λ (i) (match-define `(,x ,y) i) (* 2.0 (+ (* x 6) y))))) - (eval-res-1 (flat:build-tensor '(5 6) + (eval-res-1 (acc:build-tensor '(5 6) (λ (i) (match-define `(,x ,y) i) (* 2.0 (+ (* x 6) y)))))) @@ -62,7 +63,7 @@ (λ (i) (match-define `(,x ,y) i) (* 3.0 (+ (* x 6) y))))) - (eval-res-1 (flat:build-tensor '(3 6) + (eval-res-1 (acc:build-tensor '(3 6) (λ (i) (match-define `(,x ,y) i) (* 3.0 (+ (* x 6) y)))))) @@ -72,7 +73,7 @@ (λ (i) (match-define `(,x ,y ,z) i) (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) - (eval-res-1 (flat:build-tensor + (eval-res-1 (acc:build-tensor '(2 3 4) (λ (i) (match-define `(,x ,y ,z) i) @@ -83,7 +84,7 @@ (λ (i) (match-define `(,x ,y ,z) i) (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) - (eval-res-1 (flat:build-tensor + (eval-res-1 (acc:build-tensor '(3 5 6) (λ (i) (match-define `(,x ,y ,z) i) @@ -98,7 +99,7 @@ (let ((tp (trefs (get-test-program 'tensor-r1-0) '(0 2)))) (+-ρ tp tp))) - (eval-res-1 (flat:tensor 2 6))) + (eval-res-1 (acc:tensor 2 6))) 'built-tensor (test-program-data (λ () (let ((test-build-shape '(4 3))) @@ -109,16 +110,16 @@ (+ (* (sub1 (car test-build-shape)) row) column)))))) - (eval-res-1 (flat:tensor (flat:tensor 0 1 2) - (flat:tensor 3 4 5) - (flat:tensor 6 7 8) - (flat:tensor 9 10 11)))) + (eval-res-1 (acc:tensor (acc:tensor 0 1 2) + (acc:tensor 3 4 5) + (acc:tensor 6 7 8) + (acc:tensor 9 10 11)))) 'multi-built-tensor (test-program-data (λ () (+-ρ (get-test-program 'build-tensor-r2-0) (tref (get-test-program 'build-tensor-r3-1) 0))) - (eval-res-1 ((flat:ext2-ρ * 0 0) 2 - (flat:build-tensor + (eval-res-1 ((acc:ext2-ρ * (λ (a b) "@{a} * @{b}") 0 0) 2 + (acc:build-tensor '(5 6) (λ (i) (match-define `(,x ,y) i) @@ -134,45 +135,45 @@ 'tcomp-list->tensor (test-program-data (λ () (make-list->tensor-test-program '(5 6 7 8))) - (eval-res-1 (flat:tensor 5 6 7 8))) + (eval-res-1 (acc:tensor 5 6 7 8))) 'tcomp-nested-list->tensor (test-program-data (λ () (list->tensor `(,(get-test-program 'tensor-r1-0) ,(get-test-program 'tensor-r1-0) ,(get-test-program 'tensor-r1-0)))) - (eval-res-1 (flat:tensor - (flat:tensor 1 2 3) - (flat:tensor 1 2 3) - (flat:tensor 1 2 3)))) + (eval-res-1 (acc:tensor + (acc:tensor 1 2 3) + (acc:tensor 1 2 3) + (acc:tensor 1 2 3)))) 'tcomp-trefs (test-program-data (λ () (trefs (get-test-program 'built-tensor) '(0 2))) - (eval-res-1 (flat:tensor (flat:tensor 0 1 2) - (flat:tensor 6 7 8)))) + (eval-res-1 (acc:tensor (acc:tensor 0 1 2) + (acc:tensor 6 7 8)))) 'tcomp-reshape (test-program-data (λ () (reshape '(3 2 1) (trefs (get-test-program 'built-tensor) '(1 3)))) - (eval-res-1 (flat:tensor (flat:tensor (flat:tensor 3) - (flat:tensor 4)) - (flat:tensor (flat:tensor 5) - (flat:tensor 9)) - (flat:tensor (flat:tensor 10) - (flat:tensor 11))))) + (eval-res-1 (acc:tensor (acc:tensor (acc:tensor 3) + (acc:tensor 4)) + (acc:tensor (acc:tensor 5) + (acc:tensor 9)) + (acc:tensor (acc:tensor 10) + (acc:tensor 11))))) 'sum (test-program-data (λ () (sum (get-test-program 'tensor-r2-0))) - (eval-res-1 (flat:tensor 6.0 15.0))) + (eval-res-1 (acc:tensor 6.0 15.0))) 'sum-nested (test-program-data (λ () (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) - (eval-res-1 (flat:tensor 4.0 6.0 5.0))) + (eval-res-1 (acc:tensor 4.0 6.0 5.0))) 'id (test-program-data (λ () (id-ρ (get-test-program 'tensor-r2-0))) - (eval-res-1 (flat:tensor (flat:tensor 1 2 3) - (flat:tensor 4 5 6)))) + (eval-res-1 (acc:tensor (acc:tensor 1 2 3) + (acc:tensor 4 5 6)))) 'id-scalar (test-program-data (λ () (id-ρ (sum (tensor 4 5 6)))) @@ -185,9 +186,9 @@ (λ () (*-ρ (get-test-program 'build-tensor-r3-0) (get-test-program 'build-tensor-r3-0))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(2 3 4) - (flat:tensor + (acc:tensor 0 4 16 36 64 100 144 196 256 324 400 484 @@ -198,9 +199,9 @@ (λ () (*-2-1 (get-test-program 'build-tensor-r2-0) (get-test-program 'build-tensor-r1-0))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(5 6) - (flat:tensor + (acc:tensor 0 6.0 24.0 54.0 96.0 150.0 0 42.0 96.0 162.0 240.0 330.0 0 78.0 168.0 270.0 384.0 510.0 @@ -210,9 +211,9 @@ (λ () (*-2-1 (get-test-program 'build-tensor-r3-1) (get-test-program 'build-tensor-r2-1))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(3 5 6) - (flat:tensor + (acc:tensor 0 6.0 24.0 54.0 96.0 150.0 0 42.0 96.0 162.0 240.0 330.0 0 78.0 168.0 270.0 384.0 510.0 @@ -238,14 +239,14 @@ 'tcomp-dsqr-r1 (test-program-data (λ () (d-sqr r1-td (one-like r1-td))) - (eval-res-1 (flat:tensor 6.0 8.0 10.0))) + (eval-res-1 (acc:tensor 6.0 8.0 10.0))) 'gsqr (test-program-data (λ () (let ([r2-td (get-test-program 'tensor-r2-1)]) (d-sqr r2-td (one-like r2-td)))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(2 3) - (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + (acc:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) 'g+ (test-program-data (λ () (d+ 2.0 3.0 1.0)) @@ -253,42 +254,42 @@ 'g-twice (test-program-data (λ () (d+ r1-td r1-td (one-like r1-td))) - (eval-res-2 (flat:tensor 1.0 1.0 1.0) - (flat:tensor 1.0 1.0 1.0))) + (eval-res-2 (acc:tensor 1.0 1.0 1.0) + (acc:tensor 1.0 1.0 1.0))) 'g+-r1-r2 (test-program-data (λ () (let ((r2-td (get-test-program 'tensor-r2-1))) (d+ r1-td r2-td (one-like r2-td)))) - (eval-res-2 (flat:tensor 2.0 2.0 2.0) - (flat:reshape + (eval-res-2 (acc:tensor 2.0 2.0 2.0) + (acc:reshape '(2 3) - (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + (acc:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) 'g* (test-program-data (λ () (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0))) - (eval-res-2 (flat:tensor 1.0 2.0 3.0) - (flat:tensor 2.0 3.0 4.0))) + (eval-res-2 (acc:tensor 1.0 2.0 3.0) + (acc:tensor 2.0 3.0 4.0))) 'gsum-r1 (test-program-data (λ () (sum-∇ (tensor 2.0 3.0 4.0) 1.0)) - (eval-res-1 (flat:tensor 1.0 1.0 1.0))) + (eval-res-1 (acc:tensor 1.0 1.0 1.0))) 'gsum-r2 (test-program-data (λ () (sum-∇ (tensor (tensor 2.0 3.0 4.0) (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0))) - (eval-res-1 (flat:tensor (flat:tensor 2.0 2.0 2.0) - (flat:tensor 1.0 1.0 1.0)))) + (eval-res-1 (acc:tensor (acc:tensor 2.0 2.0 2.0) + (acc:tensor 1.0 1.0 1.0)))) 'gs2-r1 (test-program-data (λ () (s2-∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0))) - (eval-res-2 (flat:tensor 1.0 1.0 1.0) - (flat:tensor 1.0 1.0 1.0))) + (eval-res-2 (acc:tensor 1.0 1.0 1.0) + (acc:tensor 1.0 1.0 1.0))) 'gs2-r3 (test-program-data (λ () (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) @@ -309,20 +310,20 @@ (tensor 1.0 1.0)) (tensor (tensor 1.0 1.0) (tensor 1.0 1.0))))) - (eval-res-2 (flat:reshape '(3 2 3) - (flat:list->tensor (make-list 18 1.0))) - (flat:reshape '(3 2 3) - (flat:list->tensor (make-list 18 1.0))))) + (eval-res-2 (acc:reshape '(3 2 3) + (acc:list->tensor (make-list 18 1.0))) + (acc:reshape '(3 2 3) + (acc:list->tensor (make-list 18 1.0))))) 'env-flat-scalar (test-program-data (λ () ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) (list (tensor 1.0) 3.0))) - (eval-res-1 (flat:tensor 3.0))) + (eval-res-1 (acc:tensor 3.0))) 'common-subexpression (test-program-data (λ () (let ((t (tref (tensor 1 2 3) 0))) (tensor t t))) - (eval-res-1 (flat:tensor 1.0 1.0))) + (eval-res-1 (acc:tensor 1.0 1.0))) 'nested-common-subexpression (test-program-data (λ () (let ((t1 (tref (tensor (tensor 1 2 3) @@ -330,7 +331,7 @@ 0))) (let ((t0 (tref t1 0))) (tensor t0 t0)))) - (eval-res-1 (flat:tensor 1.0 1.0))) + (eval-res-1 (acc:tensor 1.0 1.0))) )) (define get-test-program @@ -345,13 +346,24 @@ (vset! out-v iₒ (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) (+ sum (vref in-v i)))))) - -(define sum (ext1-ρ sum-f 1 (λ (s) '()) #t)) +(define sum-f-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< (Vector Number) Natural (Listof Natural) (Vector Number) Natural (Listof Natural) (Vector Number) Natural (Listof Natural)))) -(struct tcomp-ext1-ρ-scalar tcomp (f sign tp) #:transparent) -(struct tcomp-ext1-ρ tcomp (f sign m shape-fn tp) #:transparent) -(struct tcomp-ext2-ρ-scalar tcomp (f sign tp-t tp-u) #:transparent) -(struct tcomp-ext2-ρ tcomp (tp-t tp-u f sign m n shape-fn) #:transparent) -(struct tcomp-ext1-∇ tcomp (tp zp f sign m shape-fn) #:transparent) -(struct tcomp-ext2-∇ tcomp (fᵈ +(struct tcomp-ext1-ρ-scalar tcomp (f f-acc sign tp) #:transparent) +(struct tcomp-ext1-ρ tcomp (f f-acc sign m shape-fn tp) #:transparent) +(struct tcomp-ext2-ρ-scalar tcomp (f f-acc sign tp-t tp-u) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u f f-acc sign m n shape-fn) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp f f-acc sign m shape-fn) #:transparent) +(struct tcomp-ext2-∇ tcomp (fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) @@ -98,7 +98,7 @@ (define gdst-trefs (λ (tp i-lst) - (box (list (tpromise-dst tp) (flat:list->tensor i-lst))))) + (box (list (tpromise-dst tp) (acc:list->tensor i-lst))))) (define gdst-ext2-∇ (λ (tp-t0 tp-t1 tp-z) @@ -216,24 +216,24 @@ (gs-trefs tp)))) (define tpmake-ext1-ρ-scalar - (λ (f signature tp shape) - (tpromise (tcomp-ext1-ρ-scalar f signature tp) shape + (λ (f f-acc prim-sign tp shape) + (tpromise (tcomp-ext1-ρ-scalar f f-acc prim-sign tp) shape (box (list (tpromise-dst tp))) - (gs-ext1-ρ-scalar signature tp)))) + (gs-ext1-ρ-scalar prim-sign tp)))) (define tpmake-ext1-ρ - (λ (f signature m shape-fn tp shape) - (tpromise (tcomp-ext1-ρ f signature m shape-fn tp) + (λ (f f-acc prim-sign m shape-fn tp shape) + (tpromise (tcomp-ext1-ρ f f-acc prim-sign m shape-fn tp) shape (box (list (tpromise-dst tp))) - (gs-ext1-ρ signature m tp)))) + (gs-ext1-ρ prim-sign m tp)))) (define tpmake-ext2-ρ-scalar - (λ (f signature tp-t tp-u shape) - (tpromise (tcomp-ext2-ρ-scalar f signature tp-t tp-u) + (λ (f f-acc prim-sign tp-t tp-u shape) + (tpromise (tcomp-ext2-ρ-scalar f f-acc prim-sign tp-t tp-u) shape (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) - (gs-ext2-ρ-scalar signature tp-t tp-u)))) + (gs-ext2-ρ-scalar prim-sign tp-t tp-u)))) (define ensure-tpromise (λ (v) @@ -243,14 +243,14 @@ (else v)))) (define tpmake-ext2-ρ - (λ (tp-t tp-u f signature m n shape-fn shape) + (λ (tp-t tp-u f f-acc prim-sign m n shape-fn shape) (let ((tp-t (ensure-tpromise tp-t)) (tp-u (ensure-tpromise tp-u))) (tpromise - (tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) + (tcomp-ext2-ρ tp-t tp-u f f-acc prim-sign m n shape-fn) shape (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) - (gs-ext2-ρ signature m n tp-t tp-u))))) + (gs-ext2-ρ prim-sign m n tp-t tp-u))))) ;; we invoke ensure-tpromise on just zp because it's the result of calling ;; force*1 which forces zp to be a non-tpromise value. We can ensure tp to @@ -258,25 +258,25 @@ ;; before passing it to this function, nor do we need scalar tp to be wrapped in ;; a tpromise. (define tpmake-ext1-∇ - (λ (tp zp f signature m shape-fn shape) + (λ (tp zp f f-acc prim-sign m shape-fn shape) (let ((zp (ensure-tpromise zp))) (tpromise - (tcomp-ext1-∇ tp zp f signature m shape-fn) + (tcomp-ext1-∇ tp zp f f-acc prim-sign m shape-fn) shape (box (list (tpromise-dst tp) (tpromise-dst zp))) - (gs-ext1-∇ signature m tp zp))))) + (gs-ext1-∇ prim-sign m tp zp))))) (define tpmake-ext2-∇ - (λ (fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) + (λ (fᵈ fᵈ-acc prim-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) (let ((tp-t0 (ensure-tpromise tp-t0)) (tp-t1 (ensure-tpromise tp-t1)) (tp-z (ensure-tpromise tp-z))) (tpromise - (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn + (tcomp-ext2-∇ fᵈ fᵈ-acc prim-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) shape (gdst-ext2-∇ tp-t0 tp-t1 tp-z) - (gs-ext2-∇ signature r0 r1 tp-t0 tp-t1 tp-z i))))) + (gs-ext2-∇ prim-sign r0 r1 tp-t0 tp-t1 tp-z i))))) (define tpmake-reshape (λ (tp shape) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 0600e04..9fad17f 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -1,26 +1,28 @@ #lang racket -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require ffi/vector) +(require "../../impl-loader.rkt") +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (struct ext2-∇-result (res) #:mutable #:transparent) (define ext2-∇-forcer! - (λ (fᵈ r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) + (λ (fᵈ fᵈ-acc fᵈ-sign r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) (let* ((f0 (ensure-flat t0)) (f1 (ensure-flat t1)) (fz (ensure-flat z)) (s0 (flat-shape f0)) (sf0 (min-shape r0 s0)) - (stride0 (flat:size-of sf0)) + (stride0 (acc:size-of sf0)) (s1 (flat-shape t1)) (sf1 (min-shape r1 s1)) - (stride1 (flat:size-of sf1)) + (stride1 (acc:size-of sf1)) (sf-z (shape-fn sf0 sf1)) - (stride-z (flat:size-of sf-z)) + (stride-z (acc:size-of sf-z)) (v0 (flat-store f0)) (v1 (flat-store f1)) @@ -32,29 +34,22 @@ (ext2-shapes s0 s1 r0 r1 sf-z (λ (sz size-z q0 q1 strides) - (let ((g0 (new-vec (flat:size-of - s0) - 0.0)) - (g1 (new-vec (flat:size-of - s1) - 0.0))) - (for ([iz (in-range - 0 - size-z - stride-z)]) - (let-values (((i0 i1) - (idxs - strides - iz - off0 - off1))) - (fᵈ g0 g1 v0 i0 - stride0 - v1 i1 - stride1 - vz - (+ offz iz) - stride-z))) + (let ((g0 (new-vec (acc:size-of s0) 0.0)) + (g1 (new-vec (acc:size-of s1) 0.0))) + (cond + ((accelerate?) + (let-values (((kernel-code kernel-name) + (ext2-∇-kernel/name fᵈ-acc fᵈ-sign strides s0 s1 r0 r1 sz + (length sf-z)))) + (run-prim2-∇! kernel-code kernel-name + g0 g1 + v0 off0 (acc:size-of s0) stride0 + v1 off1 (acc:size-of s1) stride1 + vz offz size-z stride-z))) + (else + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))))) (when out-idx0 (data-segment-set! out-idx0 (scalarize (flat s0 g0 0)))) (when out-idx1 @@ -63,7 +58,7 @@ (define rt:trefs (λ (ft b) (cond - ((= (flat:rank b) 1) (flat:trefs ft (vector->list (flat-store b)))) + ((= (acc:rank b) 1) (acc:trefs ft (map inexact->exact (f32vector->list (flat-store b))))) (else (error 'trefs-err "~a should be a tensor¹" b))))) (define data-segment @@ -81,7 +76,7 @@ (define runtime (namespace-anchor->namespace a)) -(provide runtime flat? flat:build-tensor flat:list->tensor - flat:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! +(provide runtime flat? acc:build-tensor acc:list->tensor + acc:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index 423029c..ebd84c4 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -2,8 +2,8 @@ (require "c0-ast.rkt") (require (only-in "c1-racket-runtime.rkt" - runtime flat? flat:build-tensor flat:list->tensor - set-ext2-∇-result-res! flat:tref rt:trefs ext2-∇-result-res + runtime flat? acc:build-tensor acc:list->tensor + set-ext2-∇-result-res! acc:tref rt:trefs ext2-∇-result-res ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref)) @@ -18,14 +18,14 @@ ((tpromise? arg) (interp-tpromise arg env)) ((number? arg) arg) (else (error 'interp-list->tensor "Unexpected: ~a" arg)))))) - (flat:list->tensor eval-list))] + (acc:list->tensor eval-list))] [(tcomp-tref tp (and i (tcomp-ds-ref _))) - (flat:tref (interp-tpromise tp env) + (acc:tref (interp-tpromise tp env) (interp-tcomp i env))] [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (rt:trefs (interp-tpromise tp env) (interp-tcomp b env))] - [(tcomp-ext2-∇ fᵈ _ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ fᵈ-acc f-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let ((t0-instrs (interp-tpromise tp-t0 env)) (t1-instrs (interp-tpromise tp-t1 env)) (z-instrs (interp-tpromise tp-z env))) @@ -44,28 +44,28 @@ (v (data-segment-ref index))) (cond ((eqv? v 'uncalculated) - (ext2-∇-forcer! fᵈ r0 r1 shape-fn + (ext2-∇-forcer! fᵈ fᵈ-acc f-sign r0 r1 shape-fn t0-instrs t1-instrs z-instrs out-idx0 out-idx1) (data-segment-ref index)) (else v))))] - [(tcomp-ext1-∇ tp zp f _ m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc f-sign m shape-fn) (scalarize - (flat-ext1-∇ f m shape-fn + (flat-ext1-∇ f f-acc m shape-fn f-sign (ensure-flat (interp-tpromise tp env)) (ensure-flat (interp-tpromise zp env))))] - [(tcomp-ext2-ρ-scalar f _ tp-t tp-u) + [(tcomp-ext2-ρ-scalar f f-acc _ tp-t tp-u) (f (interp-tpromise tp-t env) (interp-tpromise tp-u env))] - [(tcomp-ext2-ρ tp-t tp-u f _ m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f f-acc f-sign m n shape-fn) (scalarize - (flat-ext2-ρ f m n shape-fn + (flat-ext2-ρ f f-acc m n shape-fn f-sign (ensure-flat (interp-tpromise tp-t env)) (ensure-flat (interp-tpromise tp-u env))))] - [(tcomp-ext1-ρ-scalar f _ tp) + [(tcomp-ext1-ρ-scalar f f-acc _ tp) (f (interp-tpromise tp env))] - [(tcomp-ext1-ρ f _ m shape-fn tp) + [(tcomp-ext1-ρ f f-acc f-sign m shape-fn tp) (scalarize - (flat-ext1-ρ f m shape-fn + (flat-ext1-ρ f f-acc m shape-fn f-sign (ensure-flat (interp-tpromise tp env))))] [(tcomp-reshape s tp) (let ([interp-tp (interp-tpromise tp env)]) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 6b8437f..594eb37 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -2,8 +2,6 @@ (require "c0-ast.rkt") (require (only-in "c2-interpreter.rkt" interp-tensor interp-racket)) -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (require (only-in "c1-racket-runtime.rkt" runtime ext2-∇-result-res set-ext2-∇-result-res!)) @@ -67,7 +65,7 @@ [(tcomp-trefs tp (tcomp-ds-ref #f)) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-trefs tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref memo)) @@ -80,27 +78,27 @@ ((and (eqv? i 1) (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref1)))) (set-ext2-∇-result-res! out-ref1 (tcomp-ds-ref ref^^^)))) - (values (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ + (values (tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ out-ref0 out-ref1 i) (add1 ref^^^)))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) - (values (tcomp-ext1-∇ tp^ zp^ f sign m shape-fn) ref^^))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (values (tcomp-ext1-∇ tp^ zp^ f f-acc sign m shape-fn) ref^^))] + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) - (values (tcomp-ext2-ρ-scalar f sign tp-t^ tp-u^) ref^^))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (values (tcomp-ext2-ρ-scalar f f-acc sign tp-t^ tp-u^) ref^^))] + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) - (values (tcomp-ext2-ρ tp-t^ tp-u^ f sign m n shape-fn) ref^^))] - [(tcomp-ext1-ρ-scalar f sign tp) + (values (tcomp-ext2-ρ tp-t^ tp-u^ f f-acc sign m n shape-fn) ref^^))] + [(tcomp-ext1-ρ-scalar f f-acc sign tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) - (values (tcomp-ext1-ρ-scalar f sign tp^) ref^))] - [(tcomp-ext1-ρ f sign m shape-fn tp) + (values (tcomp-ext1-ρ-scalar f f-acc sign tp^) ref^))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) - (values (tcomp-ext1-ρ f sign m shape-fn tp^) ref^))] + (values (tcomp-ext1-ρ f f-acc sign m shape-fn tp^) ref^))] [(tcomp-reshape s tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-reshape s tp^) ref^))] @@ -173,22 +171,22 @@ (cr-tpromise tp counter^ uid^)] [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (cr-tpromise tp counter^ uid^)] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ _ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) (cr-tpromise tp-t1 counter-2 uid-2))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f _ sign m shape-fn) (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) (cr-tpromise zp counter-1 uid-1))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f _ sign tp-t tp-u) (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) (cr-tpromise tp-u counter-1 uid-1))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f _ sign m n shape-fn) (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) (cr-tpromise tp-u counter-1 uid-1))] - [(tcomp-ext1-ρ-scalar f sign tp) + [(tcomp-ext1-ρ-scalar f _ sign tp) (cr-tpromise tp counter^ uid^)] - [(tcomp-ext1-ρ f sign m shape-fn tp) + [(tcomp-ext1-ρ f _ sign m shape-fn tp) (cr-tpromise tp counter^ uid^)] [(tcomp-reshape s tp) (cr-tpromise tp counter^ uid^)] @@ -263,7 +261,7 @@ (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-trefs instrs b) tc-counter-data)))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (->ecs (ecs-tpromise tp-t0 counter) (λ (t0-instrs) @@ -274,11 +272,11 @@ (ecs-tpromise tp-z counter) (λ (z-instrs) (inj-ecs-tcomp - (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn + (tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn t0-instrs t1-instrs z-instrs out0 out1 i) tc-counter-data)))))))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) (->ecs (ecs-tpromise tp counter) (λ (t-instrs) @@ -286,9 +284,9 @@ (ecs-tpromise zp counter) (λ (z-instrs) (inj-ecs-tcomp - (tcomp-ext1-∇ t-instrs z-instrs f sign m shape-fn) + (tcomp-ext1-∇ t-instrs z-instrs f f-acc sign m shape-fn) tc-counter-data)))))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) (->ecs (ecs-tpromise tp-t counter) (λ (t-instrs) @@ -296,9 +294,9 @@ (ecs-tpromise tp-u counter) (λ (u-instrs) (inj-ecs-tcomp - (tcomp-ext2-ρ-scalar f sign t-instrs u-instrs) + (tcomp-ext2-ρ-scalar f f-acc sign t-instrs u-instrs) tc-counter-data)))))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) (->ecs (ecs-tpromise tp-t counter) (λ (t-instrs) @@ -306,18 +304,18 @@ (ecs-tpromise tp-u counter) (λ (u-instrs) (inj-ecs-tcomp - (tcomp-ext2-ρ t-instrs u-instrs f sign m n shape-fn) + (tcomp-ext2-ρ t-instrs u-instrs f f-acc sign m n shape-fn) tc-counter-data)))))] - [(tcomp-ext1-ρ-scalar f sign tp) + [(tcomp-ext1-ρ-scalar f f-acc sign tp) (->ecs (ecs-tpromise tp counter) (λ (instrs) - (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f sign instrs) tc-counter-data)))] - [(tcomp-ext1-ρ f sign m shape-fn tp) + (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f f-acc sign instrs) tc-counter-data)))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (->ecs (ecs-tpromise tp counter) (λ (instrs) - (inj-ecs-tcomp (tcomp-ext1-ρ f sign m shape-fn instrs) tc-counter-data)))] + (inj-ecs-tcomp (tcomp-ext1-ρ f f-acc sign m shape-fn instrs) tc-counter-data)))] [(tcomp-reshape s tp) (->ecs (ecs-tpromise tp counter) @@ -389,16 +387,16 @@ ((number? t) t) (else (error 'gr-list->tensor "Unexpected: ~a" t)))) lst))) - `(flat:list->tensor (list ,@instrs-list)))] + `(acc:list->tensor (list ,@instrs-list)))] [(tcomp-tref tp (and i (tcomp-ds-ref _))) (let ((instrs (gr-tpromise tp)) (i-instrs (gr-tcomp i))) - `(flat:tref ,instrs ,i-instrs))] + `(acc:tref ,instrs ,i-instrs))] [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (let ((instrs (gr-tpromise tp)) (b-instrs (gr-tcomp b))) `(rt:trefs ,instrs ,b-instrs))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let ((t0-instrs (gr-tpromise tp-t0)) (t1-instrs (gr-tpromise tp-t1)) (z-instrs (gr-tpromise tp-z)) @@ -409,36 +407,36 @@ [v (data-segment-ref index)]) (cond ((eqv? v 'uncalculated) - (ext2-∇-forcer! ,fᵈ ,r0 ,r1 ,shape-fn + (ext2-∇-forcer! ,fᵈ ,fᵈ-acc ,sign ,r0 ,r1 ,shape-fn ,t0-instrs ,t1-instrs ,z-instrs ,out-idx0 ,out-idx1) (data-segment-ref index)) (else v)))))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) (let ((t-instrs (gr-tpromise tp)) (z-instrs (gr-tpromise zp))) `(scalarize - (flat-ext1-∇ ,f ,m ,shape-fn + (flat-ext1-∇ ,f ,f-acc ,m ,shape-fn ,sign (ensure-flat ,t-instrs) (ensure-flat ,z-instrs))))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) (let ((t-instrs (gr-tpromise tp-t)) (u-instrs (gr-tpromise tp-u))) `(,f ,t-instrs ,u-instrs))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) (let ((t-instrs (gr-tpromise tp-t)) (u-instrs (gr-tpromise tp-u))) `(scalarize - (flat-ext2-ρ ,f ,m ,n ,shape-fn + (flat-ext2-ρ ,f ,f-acc ,m ,n ,shape-fn ,sign (ensure-flat ,t-instrs) (ensure-flat ,u-instrs))))] - [(tcomp-ext1-ρ-scalar f sign tp) + [(tcomp-ext1-ρ-scalar f f-acc sign tp) (let ((instrs (gr-tpromise tp))) `(,f ,instrs))] - [(tcomp-ext1-ρ f sign m shape-fn tp) + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (let ((instrs (gr-tpromise tp))) `(scalarize - (flat-ext1-ρ ,f ,m ,shape-fn + (flat-ext1-ρ ,f ,f-acc ,m ,shape-fn ,sign (ensure-flat ,instrs))))] [(tcomp-reshape s tp) (let ((instrs (gr-tpromise tp))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt index ce4481e..e266e65 100644 --- a/lazy/tensors/test/test-1-reflect.rkt +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -1,5 +1,6 @@ (module+ test (require rackunit) + (require ffi/vector) (require "0-lazy.rkt") (require "B-test-programs.rkt") @@ -14,33 +15,33 @@ ((eval-res-1 res) (let* ((tp (th)) (forced (↓ tp))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? forced res (format "Expected result doesn't match in test case ~a" test-name)) (check-pred evaluated-tpromise? tp) - (check-equal? (tpromise-shape tp) (flat:shape forced)))) + (check-equal? (tpromise-shape tp) (acc:shape forced)))) ((eval-res-2 res1 res2) (let*-values (((tp1 tp2) (th)) ((forced1) (↓ tp1)) ((forced2) (↓ tp2))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? forced1 res1 (format "Expected first result doesn't match in test case ~a" test-name)) (check-pred evaluated-tpromise? tp1) - (check-equal? (tpromise-shape tp1) (flat:shape forced1)) - (flat:check-tensor-equal? + (check-equal? (tpromise-shape tp1) (acc:shape forced1)) + (acc:check-tensor-equal? forced2 res2 (format "Expected second result doesn't match in test case ~a" test-name)) (check-pred evaluated-tpromise? tp2) - (check-equal? (tpromise-shape tp2) (flat:shape forced2)))))) + (check-equal? (tpromise-shape tp2) (acc:shape forced2)))))) (define test-tensor-r1-0 (get-test-program 'tensor-r1-0)) - (check-false (flat:flat? (tpromise-tensor test-tensor-r1-0))) - (check-true (flat:flat? (car (unbox (tpromise-dst test-tensor-r1-0))))) + (check-false (acc:flat? (tpromise-tensor test-tensor-r1-0))) + (check-true (acc:flat? (car (unbox (tpromise-dst test-tensor-r1-0))))) (check-exn exn:fail? (λ () (tensor test-tensor-r1-0 4))) (check-exn exn:fail? (λ () (tensor 4 test-tensor-r1-0))) @@ -72,7 +73,7 @@ test-tensor-r1-0))) 1) 2))) - (flat:check-tensor-equal? (↓ test-tcomp-partial-eval) + (acc:check-tensor-equal? (↓ test-tcomp-partial-eval) (↓ (tensor 1 2 3))) (define test-id-scalar (get-test-program 'id-scalar)) @@ -80,7 +81,7 @@ (+-ρ test-id-scalar (get-test-program 'sum-nested))) (void (↓ test-id-scalar)) - (flat:check-tensor-equal? (↓ test-force-scalar) + (acc:check-tensor-equal? (↓ test-force-scalar) (↓ (tensor 19 21 20))) (define test-force-subexpr @@ -91,13 +92,13 @@ (+-ρ (get-test-program 'sum-nested) (get-test-program 'sum-nested)))) (void (↓ test-force-subexpr)) - (flat:check-tensor-equal? (↓ test-force-mutate) + (acc:check-tensor-equal? (↓ test-force-mutate) (↓ (tensor 27 33 30))) (define test-tp-r1 (tensor -1 -2 -3)) (define test-force-supexpr (abs-ρ test-tp-r1)) (void (↓ test-force-supexpr)) - (flat:check-tensor-equal? (↓ test-tp-r1) + (acc:check-tensor-equal? (↓ test-tp-r1) (↓ (tensor -1 -2 -3))) (define test-trefs (get-test-program 'tcomp-trefs)) @@ -109,9 +110,9 @@ (check-pred (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) - (vector->list (flat:flat-store (↓ test-build-random))) + (f32vector->list (acc:flat-store (↓ test-build-random))) "Side-effect of generating random tensor must only be run once") - (flat:check-tensor-equal? (↓ (get-test-program 'multi-built-tensor)) + (acc:check-tensor-equal? (↓ (get-test-program 'multi-built-tensor)) (eval-res-1-res (get-test-eval-res 'multi-built-tensor))) ) diff --git a/lazy/tensors/test/test-c2-interpreter.rkt b/lazy/tensors/test/test-c2-interpreter.rkt index c20dfdb..be70371 100644 --- a/lazy/tensors/test/test-c2-interpreter.rkt +++ b/lazy/tensors/test/test-c2-interpreter.rkt @@ -1,7 +1,7 @@ (module+ test (require rackunit) (require "B-test-programs.rkt") - (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (for (((test-name test-data) (in-hash test-programs))) (match-define (test-program-data th res) test-data) @@ -9,24 +9,24 @@ ((eval-res-1 res) (let* ((tp (th)) (interped (interp-tensor tp))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? interped res (format "Expected result doesn't match in test case ~a" test-name)) - (check-equal? (tpromise-shape tp) (flat:shape interped)))) + (check-equal? (tpromise-shape tp) (acc:shape interped)))) ((eval-res-2 res1 res2) (let*-values (((tp1 tp2) (th)) ((interped1) (interp-tensor tp1)) ((interped2) (interp-tensor tp2))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? interped1 res1 (format "Expected first result doesn't match in test case ~a" test-name)) - (check-equal? (tpromise-shape tp1) (flat:shape interped1)) - (flat:check-tensor-equal? + (check-equal? (tpromise-shape tp1) (acc:shape interped1)) + (acc:check-tensor-equal? interped2 res2 (format "Expected second result doesn't match in test case ~a" test-name)) - (check-equal? (tpromise-shape tp2) (flat:shape interped2)))))) + (check-equal? (tpromise-shape tp2) (acc:shape interped2)))))) ) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index 0585290..e7be52a 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -3,7 +3,7 @@ (require "B-test-programs.rkt") (require "0-lazy.rkt") (require "c2-interpreter.rkt") - (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define current-test-program-name (make-parameter #f)) (define-check (check-compiler-invariants tp) @@ -18,7 +18,7 @@ ('test-name (current-test-program-name))) (for ((d ds)) (unless (or (number? d) - (flat:flat? d) + (acc:flat? d) (eqv? d 'uncalculated)) (fail-check (format (string-append "Data segment should only contain flat tensors " ", the symbol 'uncalculated or numbers." @@ -27,34 +27,34 @@ (parameterize ((cache (make-hash))) (let* ((instrs-dsr (generate-ds-refs tp)) (interp-dsr (interp-tensor instrs-dsr))) - (unless (flat:tensor-equal? interp-dsr interp-tp) + (unless (acc:tensor-equal? interp-dsr interp-tp) (fail-check (format (string-append "Result of interpreting pass generate-ds-ref doesn't" " match expected interpretation. Actual " - "interpretation: ~a~n")) - interp-dsr)) + "interpretation: ~a~n") + interp-dsr))) (let ((counter (count-references instrs-dsr))) (let* ((extracted (extract-common-subexpressions instrs-dsr counter)) (interp-extracted (interp-tensor extracted))) - (unless (flat:tensor-equal? interp-extracted interp-tp) + (unless (acc:tensor-equal? interp-extracted interp-tp) (fail-check (format (string-append "Result of interpreting pass" " extract-common-subexpression doesn't" " match expected interpretation. Actual " - "interpretation: ~a~n")) - interp-extracted)) + "interpretation: ~a~n") + interp-extracted))) (let* ((gr (generate-racket extracted)) (rkt (compile-racket gr)) (interp-rkt (interp-racket rkt ds))) - (unless (flat:tensor-equal? interp-rkt interp-tp) + (unless (acc:tensor-equal? interp-rkt interp-tp) (fail-check (format (string-append "Result of interpreting compiled racket code doesn't" " match expected interpretation. Actual " - "interpretation: ~a~n")) - interp-rkt)) + "interpretation: ~a~n") + interp-rkt))) (hash-set! (cache) signature rkt) (compile-tensor tp) (unless (eqv? (hash-count (cache)) 1) @@ -167,27 +167,27 @@ (else (error 'cdsr-list->tensor "Unexpected: ~a" l))))] [(tcomp-tref tp _) (count-tcomp-var tp)] [(tcomp-trefs tp _) (count-tcomp-var tp)] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + [(tcomp-ext2-∇ fᵈ _ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) (let ((c0 (count-tcomp-var tp-t0)) (c1 (count-tcomp-var tp-t1)) (cz (count-tcomp-var tp-z))) (+ c0 c1 cz))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f _ sign m shape-fn) (let ((ct (count-tcomp-var tp)) (cz (count-tcomp-var zp))) (+ ct cz))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f _ sign tp-t tp-u) (let ((ct (count-tcomp-var tp-t)) (cu (count-tcomp-var tp-u))) (+ ct cu))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f _ sign m n shape-fn) (let ((ct (count-tcomp-var tp-t)) (cu (count-tcomp-var tp-u))) (+ ct cu))] - [(tcomp-ext1-ρ-scalar f sign tp) (count-tcomp-var tp)] - [(tcomp-ext1-ρ f sign m shape-fn tp) (count-tcomp-var tp)] + [(tcomp-ext1-ρ-scalar f _ sign tp) (count-tcomp-var tp)] + [(tcomp-ext1-ρ f _ sign m shape-fn tp) (count-tcomp-var tp)] [(tcomp-reshape s tp) (count-tcomp-var tp)] [(tcomp-ds-ref i) 0] [(tcomp-let lhs rhs body) From febf3aadb214760a4eeb6c5638bcc380630cdbaf Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:45:31 -0400 Subject: [PATCH 83/83] [add-lazy]Add primitive ast nodes to compiled impl --- lazy/autodiff/B-prims.rkt | 66 ++++++++-------- lazy/autodiff/D-test-helpers.rkt | 18 ++++- lazy/ext-ops/test/test-C-star-2-1.rkt | 4 + lazy/ext-ops/test/test-D-sum.rkt | 15 +++- lazy/ext-ops/test/test-E-argmax.rkt | 2 + lazy/ext-ops/test/test-F-max.rkt | 4 + lazy/ext-ops/test/test-K-concat.rkt | 5 ++ lazy/tensors/c0-ast.rkt | 67 ++++++++++++++++ lazy/tensors/c1-racket-runtime.rkt | 31 +++++++- lazy/tensors/c2-interpreter.rkt | 33 +++++++- lazy/tensors/c3-compiler.rkt | 104 +++++++++++++++++++++++++ lazy/tensors/test/test-c3-compiler.rkt | 8 +- 12 files changed, 312 insertions(+), 45 deletions(-) diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index e14ffab..5b4b252 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -1,12 +1,13 @@ #lang racket -(require (only-in "../../accelerated-tensors/ext-impl.rkt" - new-vec - apply-flat-ρ-fn-1 - apply-flat-ρ-fn-2 - apply-flat-∇-fn-1 - apply-flat-∇-fn-2)) +(require (only-in "../tensors/c0-ast.rkt" + tpmake-prim1-ρ + tpmake-prim2-ρ + tpmake-prim1-∇ + tpmake-prim2-∇)) (require "../tensors.rkt") +(require (only-in "../tensors/c1-racket-runtime.rkt" ext2-∇-result)) +(require (only-in "../tensors/c0-ast.rkt" tcomp-ds-ref)) (require "A-autodiff.ss") (struct prim (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape-fn signature expects-prealloc? proc) @@ -21,8 +22,12 @@ (set! id (add1 id)) (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da) - (prim1-dual (if #;#f expects-prealloc? (preallocated->functional-1-ρ ρ-fn shape) ρ-fn) - (if #;#f expects-prealloc? (preallocated->functional-1-∇ ∇-fn shape) ∇-fn) + (prim1-dual (if expects-prealloc? + (preallocated->functional-1-ρ ρ-fn ρ-acc-fn prim-sign shape) + ρ-fn) + (if expects-prealloc? + (preallocated->functional-1-∇ ∇-fn ∇-acc-fn prim-sign shape) + ∇-fn) da))))))) ;; TODO: Convert the use of force* into the construction of an AST so that we @@ -43,8 +48,12 @@ (set! id (add1 id)) (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da db) - (prim2-dual (if expects-prealloc? (preallocated->functional-2-ρ ρ-fn shape) ρ-fn) - (if expects-prealloc? (preallocated->functional-2-∇ ∇-fn shape) ∇-fn) + (prim2-dual (if expects-prealloc? + (preallocated->functional-2-ρ ρ-fn ρ-acc-fn prim-sign shape) + ρ-fn) + (if expects-prealloc? + (preallocated->functional-2-∇ ∇-fn ∇-acc-fn prim-sign shape) + ∇-fn) da db))))))) (define prim2-dual @@ -64,41 +73,28 @@ ;;---------------------------- (define preallocated->functional-1-ρ - (λ (ρ-fn shape-fn) + (λ (ρ-fn ρ-fn-acc prim-sign shape-fn) (λ (ra) - (force*1 ra - (λ (ra) - (apply-flat-ρ-fn-1 ρ-fn ra shape-fn)))))) + (tpmake-prim1-ρ ρ-fn ρ-fn-acc prim-sign shape-fn ra)))) (define preallocated->functional-1-∇ - (λ (∇-fn shape-fn) + (λ (∇-fn ∇-fn-acc prim-sign shape-fn) (λ (ra z) - (force*2 - (λ () - (values ra z)) - (λ (ra z) - (apply-flat-∇-fn-1 ∇-fn ra z shape-fn)))))) + (tpmake-prim1-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra z)))) (define preallocated->functional-2-ρ - (λ (ρ-fn shape-fn) + (λ (ρ-fn ρ-fn-acc prim-sign shape-fn) (λ (ra rb) - (force*2 - (λ () - (values ra rb)) - (λ (ra rb) - (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn)))))) + (tpmake-prim2-ρ ρ-fn ρ-fn-acc prim-sign shape-fn ra rb)))) (define preallocated->functional-2-∇ - (λ (∇-fn shape-fn) + (λ (∇-fn ∇-fn-acc prim-sign shape-fn) (λ (ra rb z) - (force*2 - (λ () - (values ra rb)) - (λ (ra rb) - (force*1 - z - (λ (z) - (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn)))))))) + (let ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) + (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) + (values + (tpmake-prim2-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra rb z out-ref0 out-ref1 0) + (tpmake-prim2-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra rb z out-ref0 out-ref1 1)))))) ;;---------------------------- ;; Dualized tensor op creators diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index dde0e59..a680951 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -1,10 +1,20 @@ #lang racket (require "../tensors.rkt") +(require "../tensors/c0-ast.rkt") (require "A-autodiff.ss") +(require (except-in "../../accelerated-tensors/ext-impl.rkt" + scalarize)) (require rackunit) +(define force-print-store + (λ (t) + (with-output-to-string + (λ () + (print-vec (flat-store (↓ t) + #;(list-ref (unbox (tpromise-dst t)) 0))))))) + (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) (let* ((y (apply fn args)) @@ -14,11 +24,11 @@ ((and (equal-wt? ans-ρ (ρ y)) (equal-wt? grads (ρ g))) (void)) ((equal-wt? ans-ρ (ρ y)) - (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~%~s~%" - (ρ g) grads))) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~%~s~%~%actual store:~%~a~%expected store:~%~a~%" + (ρ g) grads (map force-print-store (ρ g)) (map force-print-store grads)))) (else - (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~%~s~%" - (ρ y) ans-ρ)))))) + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~%~s~%~%actual store:~%~a~%expected store:~%~a~%" + (ρ y) ans-ρ (force-print-store (ρ y)) (force-print-store ans-ρ))))))) (define-syntax check-ρ-∇ (syntax-rules () diff --git a/lazy/ext-ops/test/test-C-star-2-1.rkt b/lazy/ext-ops/test/test-C-star-2-1.rkt index bbb2c8b..233466e 100644 --- a/lazy/ext-ops/test/test-C-star-2-1.rkt +++ b/lazy/ext-ops/test/test-C-star-2-1.rkt @@ -6,6 +6,10 @@ (tensor 7 8 9 10))) (b (tensor 2 3 4 5))) (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + (check-ρ-∇ (*-2-1 a b) (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) (tensor 10.0 12.0 14.0 16.0)))) diff --git a/lazy/ext-ops/test/test-D-sum.rkt b/lazy/ext-ops/test/test-D-sum.rkt index e77ab0a..9271f11 100644 --- a/lazy/ext-ops/test/test-D-sum.rkt +++ b/lazy/ext-ops/test/test-D-sum.rkt @@ -5,6 +5,8 @@ (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) (check-ρ-∇ (d-sum a) 12 (list (tensor 1.0 1.0 1.0)))) @@ -50,4 +52,15 @@ (list (tensor (tensor 14.0 16.0 18.0 20.0) (tensor 14.0 16.0 18.0 20.0)) (tensor (tensor 10.0 12.0 14.0 16.0) - (tensor 10.0 12.0 14.0 16.0)))))) + (tensor 10.0 12.0 14.0 16.0))))) + (let ((a (tensor (tensor 0 1 2) + (tensor 3 4 5) + (tensor 6 7 8)))) + (check-ρ-∇ (sum-cols-2 a) (tensor 9 12 15) + (list (tensor (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0)))) + (check-ρ-∇ (d-sum-cols a) (tensor 9 12 15) + (list (tensor (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0)))))) diff --git a/lazy/ext-ops/test/test-E-argmax.rkt b/lazy/ext-ops/test/test-E-argmax.rkt index f72819e..144980a 100644 --- a/lazy/ext-ops/test/test-E-argmax.rkt +++ b/lazy/ext-ops/test/test-E-argmax.rkt @@ -2,6 +2,8 @@ (require (only-in "../tensors.rkt" tensor)) (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (argmax-1 y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0))) (check-ρ-∇ (d-argmax y) 2.0 (list (tensor 0.0 0.0 0.0 0.0)))) diff --git a/lazy/ext-ops/test/test-F-max.rkt b/lazy/ext-ops/test/test-F-max.rkt index 01ab1a5..88d74a5 100644 --- a/lazy/ext-ops/test/test-F-max.rkt +++ b/lazy/ext-ops/test/test-F-max.rkt @@ -2,6 +2,10 @@ (require rackunit) (require (only-in "../tensors.rkt" tensor)) + (let ((y (tensor 0.0 1.0 0.0 0.0))) + (check-ρ-∇ (max-1 y) 1.0 (list y)) + (check-ρ-∇ (d-max y) 1.0 (list y))) + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) (tensor 0.0 1.0 0.0 0.0) (tensor 1.0 0.0 0.0 0.0) diff --git a/lazy/ext-ops/test/test-K-concat.rkt b/lazy/ext-ops/test/test-K-concat.rkt index b427ced..aa405fc 100644 --- a/lazy/ext-ops/test/test-K-concat.rkt +++ b/lazy/ext-ops/test/test-K-concat.rkt @@ -8,6 +8,11 @@ (define r1-t2 (tensor 5.0 6.0 7.0)) (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + (check-ρ-∇ (concat-1-1 r1-t2 r1-t1) + (tensor 5.0 6.0 7.0 3.0 4.0 5.0 6.0 7.0) + (list (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0 1.0))) + (check-dual-equal? (d-concat r2-t1 r1-t2) (tensor (tensor 3.0 4.0 5.0 6.0 7.0) diff --git a/lazy/tensors/c0-ast.rkt b/lazy/tensors/c0-ast.rkt index 1e2d9e1..ae03f89 100644 --- a/lazy/tensors/c0-ast.rkt +++ b/lazy/tensors/c0-ast.rkt @@ -37,6 +37,10 @@ tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) #:transparent) +(struct tcomp-prim1-ρ tcomp (f f-acc sign shape-fn tp) #:transparent) +(struct tcomp-prim1-∇ tcomp (f f-acc sign shape-fn tp zp) #:transparent) +(struct tcomp-prim2-ρ tcomp (f f-acc sign shape-fn tp-t tp-u) #:transparent) +(struct tcomp-prim2-∇ tcomp (f f-acc sign shape-fn tp-t tp-u zp out-ref0 out-ref1 i) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) (struct tcomp-let tcomp (lhs rhs body) #:transparent) (struct tcomp-var tcomp (name) #:transparent) @@ -180,6 +184,25 @@ (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) #"dsr" (number->bytes i))))) +(define gs-prim1-ρ + (λ (prim-sign tp) + (box (list #"p1r" (string->bytes prim-sign) (tpromise-sign tp))))) + +(define gs-prim2-ρ + (λ (signature tp-t tp-u) + (box (list #"p2r" (string->bytes signature) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-prim1-∇ + (λ (signature tp zp) + (box (list #"p1n" (string->bytes signature) (tpromise-sign tp) (tpromise-sign zp))))) + +(define gs-prim2-∇ + (λ (signature tp-t0 tp-t1 tp-z i) + (box (list #"p2n" (string->bytes signature) + (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) + #"dsr" (number->bytes i))))) + (define gs-reshape (λ (shape tp) (box (list* #"r" (tpromise-sign tp) (map number->bytes shape))))) @@ -278,6 +301,42 @@ (gdst-ext2-∇ tp-t0 tp-t1 tp-z) (gs-ext2-∇ prim-sign r0 r1 tp-t0 tp-t1 tp-z i))))) +(define tpmake-prim1-ρ + (λ (f f-acc prim-sign shape-fn tp) + (tpromise (tcomp-prim1-ρ f f-acc prim-sign shape-fn tp) + (shape-fn (tpromise-shape tp)) + (box (list (tpromise-dst tp))) + (gs-prim1-ρ prim-sign tp)))) + +(define tpmake-prim2-ρ + (λ (f f-acc prim-sign shape-fn tp-t tp-u) + (tpromise + (tcomp-prim2-ρ f f-acc prim-sign shape-fn tp-t tp-u) + (shape-fn (tpromise-shape tp-t) (tpromise-shape tp-u)) + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-prim2-ρ prim-sign tp-t tp-u)))) + +(define tpmake-prim1-∇ + (λ (f f-acc prim-sign shape-fn tp zp) + (let ((zp (ensure-tpromise zp))) + (tpromise + (tcomp-prim1-∇ f f-acc prim-sign shape-fn tp zp) + (tpromise-shape tp) + (box (list (tpromise-dst tp) (tpromise-dst zp))) + (gs-prim1-∇ prim-sign tp zp))))) + +(define tpmake-prim2-∇ + (λ (fᵈ fᵈ-acc prim-sign shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + (let ((tp-t0 (ensure-tpromise tp-t0)) + (tp-t1 (ensure-tpromise tp-t1)) + (tp-z (ensure-tpromise tp-z))) + (tpromise + (tcomp-prim2-∇ fᵈ fᵈ-acc prim-sign shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + (if (zero? i) (tpromise-shape tp-t0) (tpromise-shape tp-t1)) + (gdst-ext2-∇ tp-t0 tp-t1 tp-z) ;; dst constucted in the same way as ext2-∇ + (gs-prim2-∇ prim-sign tp-t0 tp-t1 tp-z i))))) + (define tpmake-reshape (λ (tp shape) (tpromise @@ -295,6 +354,10 @@ (struct-out tcomp-ext2-ρ) (struct-out tcomp-ext1-∇) (struct-out tcomp-ext2-∇) + (struct-out tcomp-prim1-ρ) + (struct-out tcomp-prim2-ρ) + (struct-out tcomp-prim1-∇) + (struct-out tcomp-prim2-∇) (struct-out tcomp-reshape) (struct-out tcomp-let) (struct-out tcomp-var) @@ -314,4 +377,8 @@ tpmake-ext2-ρ tpmake-ext1-∇ tpmake-ext2-∇ + tpmake-prim1-ρ + tpmake-prim2-ρ + tpmake-prim1-∇ + tpmake-prim2-∇ tpmake-reshape) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 9fad17f..70561b8 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -55,6 +55,33 @@ (when out-idx1 (data-segment-set! out-idx1 (scalarize (flat s1 g1 0)))))))))) +(define prim2-∇-forcer! + (λ (fᵈ fᵈ-acc fᵈ-sign shape-fn t0 t1 z out-idx0 out-idx1) + (let* ((in-shape-a (flat-shape t0)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape t1)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (let ((g0 (new-vec in-size-a 0.0)) + (g1 (new-vec in-size-b 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (fᵈ g0 g1 + (flat-store t0) (flat-offset t0) in-size-a + (flat-store t1) (flat-offset t1) in-size-b + v-z 0 1))) + (else + (fᵈ g0 g1 + (flat-store t0) (flat-offset t0) in-size-a + (flat-store t1) (flat-offset t1) in-size-b + (flat-store z) (flat-offset z) out-size))) + (when out-idx0 + (data-segment-set! out-idx0 (scalarize (flat in-shape-a g0 0)))) + (when out-idx1 + (data-segment-set! out-idx1 (scalarize (flat in-shape-b g1 0)))))))) + (define rt:trefs (λ (ft b) (cond @@ -79,4 +106,6 @@ (provide runtime flat? acc:build-tensor acc:list->tensor acc:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ - flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref) + flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref + apply-flat-ρ-fn-1 apply-flat-ρ-fn-2 apply-flat-∇-fn-1 apply-flat-∇-fn-2 + prim2-∇-forcer!) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index ebd84c4..7bbba2c 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -6,7 +6,8 @@ set-ext2-∇-result-res! acc:tref rt:trefs ext2-∇-result-res ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment - data-segment-ref)) + apply-flat-ρ-fn-1 apply-flat-ρ-fn-2 apply-flat-∇-fn-1 apply-flat-∇-fn-2 + data-segment-ref prim2-∇-forcer!)) (define interp-tcomp (λ (tc env) @@ -67,6 +68,36 @@ (scalarize (flat-ext1-ρ f f-acc m shape-fn f-sign (ensure-flat (interp-tpromise tp env))))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (apply-flat-ρ-fn-1 f (interp-tpromise tp) shape-fn)] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (apply-flat-ρ-fn-2 f (interp-tensor tp-t) (interp-tpromise tp-u) shape-fn)] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (apply-flat-∇-fn-1 f (interp-tpromise tp) (scalarize (interp-tpromise zp)) shape-fn)] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (interp-tpromise tp-t0 env)) + (t1-instrs (interp-tpromise tp-t1 env)) + (z-instrs (interp-tpromise tp-z env))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out0)))) + (set-ext2-∇-result-res! out0 (tcomp-ds-ref (current-ds-ref-index))) + (current-ds-ref-index (add1 (current-ds-ref-index)))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (set-ext2-∇-result-res! out1 (tcomp-ds-ref (current-ds-ref-index))) + (current-ds-ref-index (add1 (current-ds-ref-index))))) + (let* ((out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1))) + (index (if (zero? i) out-idx0 out-idx1)) + (v (data-segment-ref index))) + (cond + ((eqv? v 'uncalculated) + (prim2-∇-forcer! fᵈ f-sign shape-fn + t0-instrs t1-instrs + z-instrs out-idx0 out-idx1) + (data-segment-ref index)) + (else v))))] [(tcomp-reshape s tp) (let ([interp-tp (interp-tpromise tp env)]) (flat s (flat-store interp-tp) (flat-offset interp-tp)))] diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 594eb37..15d171c 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -99,6 +99,31 @@ [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-ext1-ρ f f-acc sign m shape-fn tp^) ref^))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-prim1-ρ f f-acc sign shape-fn tp^) ref^))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-prim2-ρ f f-acc sign shape-fn tp-t^ tp-u^) ref^^))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) + ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) + (values (tcomp-prim1-∇ f f-acc sign shape-fn tp^ zp^) ref^^))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref memo)) + ((tp-t1^ ref^^) (gdr-tpromise tp-t1 ref^ memo)) + ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^ memo))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out0)))) + (set-ext2-∇-result-res! out0 (tcomp-ds-ref ref^^^))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (set-ext2-∇-result-res! out1 (tcomp-ds-ref ref^^^)))) + (values (tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0^ tp-t1^ tp-z^ + out0 out1 i) + (add1 ref^^^)))] [(tcomp-reshape s tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-reshape s tp^) ref^))] @@ -188,6 +213,18 @@ (cr-tpromise tp counter^ uid^)] [(tcomp-ext1-ρ f _ sign m shape-fn tp) (cr-tpromise tp counter^ uid^)] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) + (cr-tpromise zp counter-1 uid-1))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) + ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) + (cr-tpromise tp-t1 counter-2 uid-2))] [(tcomp-reshape s tp) (cr-tpromise tp counter^ uid^)] [(tcomp-ds-ref index) (values counter^ uid^)] @@ -316,6 +353,46 @@ (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-ext1-ρ f f-acc sign m shape-fn instrs) tc-counter-data)))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-prim1-ρ f f-acc sign shape-fn instrs) tc-counter-data)))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (->ecs + (ecs-tpromise tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-prim2-ρ f f-acc sign shape-fn t-instrs u-instrs) + tc-counter-data)))))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (->ecs + (ecs-tpromise tp counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise zp counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-prim1-∇ f f-acc sign shape-fn t-instrs z-instrs) + tc-counter-data)))))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (->ecs + (ecs-tpromise tp-t0 counter) + (λ (t0-instrs) + (->ecs + (ecs-tpromise tp-t1 counter) + (λ (t1-instrs) + (->ecs + (ecs-tpromise tp-z counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn + t0-instrs t1-instrs z-instrs + out0 out1 i) + tc-counter-data)))))))] [(tcomp-reshape s tp) (->ecs (ecs-tpromise tp counter) @@ -438,6 +515,33 @@ `(scalarize (flat-ext1-ρ ,f ,f-acc ,m ,shape-fn ,sign (ensure-flat ,instrs))))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (let ((instrs (gr-tpromise tp))) + `(apply-flat-ρ-fn-1 ,f ,instrs ,shape-fn))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) + `(apply-flat-ρ-fn-2 ,f ,t-instrs ,u-instrs ,shape-fn))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let ((t-instrs (gr-tpromise tp)) + (z-instrs (gr-tpromise zp))) + `(apply-flat-∇-fn-1 ,f ,t-instrs (scalarize ,z-instrs) ,shape-fn))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (gr-tpromise tp-t0)) + (t1-instrs (gr-tpromise tp-t1)) + (z-instrs (gr-tpromise tp-z)) + (out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (let ((index (if (zero? i) out-idx0 out-idx1))) + `(let* ([index ,index] + [v (data-segment-ref index)]) + (cond + ((eqv? v 'uncalculated) + (prim2-∇-forcer! ,fᵈ ,fᵈ-acc ,f-sign ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out-idx0 ,out-idx1) + (data-segment-ref index)) + (else v)))))] [(tcomp-reshape s tp) (let ((instrs (gr-tpromise tp))) `(flat ',s diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index e7be52a..6bae788 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -3,6 +3,8 @@ (require "B-test-programs.rkt") (require "0-lazy.rkt") (require "c2-interpreter.rkt") + (require (prefix-in acc: (only-in "../../accelerated-tensors/autodiff.rkt" + make-printable))) (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define current-test-program-name (make-parameter #f)) @@ -14,7 +16,7 @@ (('data-segment ds) ('signature signature) ('input-computation (tpromise-tensor tp)) - ('expected-interpretation interp-tp) + ('expected-interpretation (acc:make-printable interp-tp)) ('test-name (current-test-program-name))) (for ((d ds)) (unless (or (number? d) @@ -44,7 +46,7 @@ " extract-common-subexpression doesn't" " match expected interpretation. Actual " "interpretation: ~a~n") - interp-extracted))) + (acc:make-printable interp-extracted)))) (let* ((gr (generate-racket extracted)) (rkt (compile-racket gr)) (interp-rkt (interp-racket rkt ds))) @@ -54,7 +56,7 @@ "Result of interpreting compiled racket code doesn't" " match expected interpretation. Actual " "interpretation: ~a~n") - interp-rkt))) + (acc:make-printable interp-rkt)))) (hash-set! (cache) signature rkt) (compile-tensor tp) (unless (eqv? (hash-count (cache)) 1)