From 0930df7a1d1a2dc161e30e2a480485a5b3614f89 Mon Sep 17 00:00:00 2001 From: Darshal Shetty Date: Tue, 16 Jul 2024 20:40:54 -0400 Subject: [PATCH] [add-lazy]Switch compiled tensor runtime to acc tensor impl --- lazy/autodiff/A-autodiff.rkt | 5 +- lazy/autodiff/B-prims.rkt | 89 +++++++-- lazy/autodiff/E-print.rkt | 2 +- lazy/autodiff/test/test-E-print.rkt | 50 ++--- lazy/ext-ops/A-scalar-ops.rkt | 134 ++++++++++---- lazy/ext-ops/B-comparators.rkt | 36 ++-- lazy/ext-ops/C-star-2-1.rkt | 60 ++++-- lazy/ext-ops/D-sum.rkt | 76 ++++++-- lazy/ext-ops/E-argmax.rkt | 43 ++++- lazy/ext-ops/F-max.rkt | 57 +++++- lazy/ext-ops/G-correlate.rkt | 89 +++++++-- lazy/ext-ops/I-flatten.rkt | 25 ++- lazy/ext-ops/K-concat.rkt | 55 +++++- lazy/ext-ops/test/test-G-correlate.rkt | 4 +- lazy/tensors/0-lazy.rkt | 90 ++++----- lazy/tensors/1-reflect.rkt | 9 +- lazy/tensors/A-equality.rkt | 6 +- lazy/tensors/B-test-programs.rkt | 212 ++++++++++++++-------- lazy/tensors/c0-ast.rkt | 54 +++--- lazy/tensors/c1-racket-runtime.rkt | 59 +++--- lazy/tensors/c2-interpreter.rkt | 28 +-- lazy/tensors/c3-compiler.rkt | 86 +++++---- lazy/tensors/test/test-1-reflect.rkt | 29 +-- lazy/tensors/test/test-c2-interpreter.rkt | 14 +- lazy/tensors/test/test-c3-compiler.rkt | 34 ++-- 25 files changed, 908 insertions(+), 438 deletions(-) diff --git a/lazy/autodiff/A-autodiff.rkt b/lazy/autodiff/A-autodiff.rkt index c23a9ba..cae86b0 100644 --- a/lazy/autodiff/A-autodiff.rkt +++ b/lazy/autodiff/A-autodiff.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../tensors.rkt") ;;---------------------------- @@ -52,7 +53,7 @@ (hash-set σ d (+-ρ z g))))) (define +-ρ - (ext2-ρ + 0 0)) + (ext2-ρ + (λ (a b) "@{a} + @{b}") 0 0)) ;;---------------------------- ;; Reverse-mode AD @@ -111,7 +112,7 @@ ((dual? 
v) (trace-print (ρ v) port)) (else (fprintf port "~a~%" v))))) -(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s)) +(define (one-like s) ((ext1-ρ (λ (x) 1.0) (λ (x) "1.0") 0) s)) (include "test/test-A-autodiff.rkt") diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index 94796d3..e14ffab 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -1,21 +1,32 @@ #lang racket +(require (only-in "../../accelerated-tensors/ext-impl.rkt" + new-vec + apply-flat-ρ-fn-1 + apply-flat-ρ-fn-2 + apply-flat-∇-fn-1 + apply-flat-∇-fn-2)) (require "../tensors.rkt") (require "A-autodiff.ss") -(struct prim (ρ-fn ∇-fn shape-fn signature expects-prealloc? proc) +(struct prim (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape-fn signature expects-prealloc? proc) #:property prop:procedure (λ (this . args) (apply (prim-proc this) args))) +;;TODO: Add new ast nodes for the 4 forces being done in the four preallocated->functional-* functions (define prim1 (let ((id 0)) - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) (let ((prim-sign (string-append "p1" (~r id #:base 16)))) (set! id (add1 id)) - (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da) - (prim1-dual ρ-fn ∇-fn da))))))) + (prim1-dual (if #;#f expects-prealloc? (preallocated->functional-1-ρ ρ-fn shape) ρ-fn) + (if #;#f expects-prealloc? (preallocated->functional-1-∇ ∇-fn shape) ∇-fn) + da))))))) +;; TODO: Convert the use of force* into the construction of an AST so that we +;; don't prematurely trigger computation. (define prim1-dual (λ (ρ-fn ∇-fn da) (let ((ra (ρ da))) @@ -27,12 +38,14 @@ (define prim2 (let ((id 0)) - (λ (ρ-fn ∇-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) + (λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f]) (let ((prim-sign (string-append "p2" (~r id #:base 16)))) (set! 
id (add1 id)) - (prim ρ-fn ∇-fn shape prim-sign expects-prealloc? + (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da db) - (prim2-dual ρ-fn ∇-fn da db))))))) + (prim2-dual (if expects-prealloc? (preallocated->functional-2-ρ ρ-fn shape) ρ-fn) + (if expects-prealloc? (preallocated->functional-2-∇ ∇-fn shape) ∇-fn) + da db))))))) (define prim2-dual (λ (ρ-fn ∇-fn da db) @@ -45,6 +58,48 @@ (let ((σ-hat ((κ da) da ga σ))) ((κ db) db gb σ-hat))))))))) +;;---------------------------- +;; Managing flat-optimized and +;; non-flat ρ and ∇ functions +;;---------------------------- + +(define preallocated->functional-1-ρ + (λ (ρ-fn shape-fn) + (λ (ra) + (force*1 ra + (λ (ra) + (apply-flat-ρ-fn-1 ρ-fn ra shape-fn)))))) + +(define preallocated->functional-1-∇ + (λ (∇-fn shape-fn) + (λ (ra z) + (force*2 + (λ () + (values ra z)) + (λ (ra z) + (apply-flat-∇-fn-1 ∇-fn ra z shape-fn)))))) + +(define preallocated->functional-2-ρ + (λ (ρ-fn shape-fn) + (λ (ra rb) + (force*2 + (λ () + (values ra rb)) + (λ (ra rb) + (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn)))))) + +(define preallocated->functional-2-∇ + (λ (∇-fn shape-fn) + (λ (ra rb z) + (force*2 + (λ () + (values ra rb)) + (λ (ra rb) + (force*1 + z + (λ (z) + (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn)))))))) + ;;---------------------------- ;; Dualized tensor op creators ;;---------------------------- @@ -53,10 +108,12 @@ (unless (prim? f) (error 'ext1-prim "Function to be extended must be a primitive. Found: ~a" f)) (prim1 - (ext1-ρ (prim-ρ-fn f) n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) - (ext1-∇ (prim-∇-fn f) n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) + (ext1-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "r" (prim-signature f))) + (prim-ρ-acc-fn f) + (ext1-∇ (prim-∇-fn f) (prim-∇-acc-fn f) n (prim-shape-fn f) + (prim-expects-prealloc? 
f) (string-append "n" (prim-signature f))) + (prim-∇-acc-fn f) (prim-shape-fn f) #f))) @@ -65,10 +122,12 @@ (unless (prim? f) (error 'ext2-prim "Function to be extended must be a primitive. Found: ~a" f)) (prim2 - (ext2-ρ (prim-ρ-fn f) m n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) - (ext2-∇ (prim-∇-fn f) m n (prim-shape-fn f) - (prim-expects-prealloc? f) (prim-signature f)) + (ext2-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "r" (prim-signature f))) + (prim-ρ-acc-fn f) + (ext2-∇ (prim-∇-fn f) (prim-∇-acc-fn f) m n (prim-shape-fn f) + (prim-expects-prealloc? f) (string-append "n" (prim-signature f))) + (prim-∇-acc-fn f) (prim-shape-fn f) #f))) diff --git a/lazy/autodiff/E-print.rkt b/lazy/autodiff/E-print.rkt index 270083d..b39f25a 100644 --- a/lazy/autodiff/E-print.rkt +++ b/lazy/autodiff/E-print.rkt @@ -3,7 +3,7 @@ (require "A-autodiff.rkt") (require "../tensors/0-lazy.rkt") (require "../tensors/1-reflect.rkt") -(require (except-in "../../flat-tensors/ext-impl.rkt" scalarize)) +(require (except-in "../../accelerated-tensors/ext-impl.rkt" scalarize)) (define max-tensor-print-length (make-parameter 5)) diff --git a/lazy/autodiff/test/test-E-print.rkt b/lazy/autodiff/test/test-E-print.rkt index 91fde6c..092f78b 100644 --- a/lazy/autodiff/test/test-E-print.rkt +++ b/lazy/autodiff/test/test-E-print.rkt @@ -18,54 +18,54 @@ deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor)) - (check-equal? (make-printable long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) (check-equal? (make-printable deep-tensor 3) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...))) (check-equal? 
(make-printable deeper-tensor 3) (fake-tensor (list (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) '...))) (parameterize ((max-tensor-print-length 3)) - (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1 2 3 ...))) + (check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...))) (check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor)) (list - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) (fake-tensor (list (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) (fake-tensor - (list (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) - (fake-tensor '(1 2 3 ...)) + (list (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) + (fake-tensor '(1.0 2.0 3.0 ...)) '...)) '...)))))) diff --git a/lazy/ext-ops/A-scalar-ops.rkt b/lazy/ext-ops/A-scalar-ops.rkt index 2049b41..72b33f9 100644 --- 
a/lazy/ext-ops/A-scalar-ops.rkt +++ b/lazy/ext-ops/A-scalar-ops.rkt @@ -1,49 +1,108 @@ #lang racket +(require string-interpolation) (require (only-in "../tensors.rkt" ext1-ρ ext2-ρ)) (require "../autodiff.rkt") +(define +-0-0-ρ-acc + (λ (a b) + "@{a}+@{b}")) + (define +-0-0 (prim2 + + +-0-0-ρ-acc + (λ (a b z) + (values z z)) (λ (a b z) (values z z)))) +(define --0-0-ρ-acc + (λ (a b) + "@{a}-@{b}")) + (define --0-0 (prim2 - + --0-0-ρ-acc + (λ (a b z) + (values z (- z))) (λ (a b z) - (values z (- z))))) + (values z "(- @{z})")))) + +(define *-0-0-ρ-acc + (λ (a b) + "@{a}*@{b}")) (define *-0-0 (prim2 * + *-0-0-ρ-acc + (λ (a b z) + (values (* b z) (* a z))) (λ (a b z) - (values (* b z) (* a z))))) + (values "@{b}*@{z}" "@{a}*@{z}")))) + +(define /-0-0-ρ-acc + (λ (a b) + "@{a}/@{b}")) (define /-0-0 (prim2 / - (λ (a b z) - (values (* z (/ 1 b)) - (* z (/ (- a) (* b b))))))) + /-0-0-ρ-acc + (λ (a b z) + (values (* z (/ 1 b)) + (* z (/ (- a) (* b b))))) + (λ (a b z) + (values "(@{z} * (1 / @{b}))" + "(@{z} * ((- @{a}) / (@{b} * @{b})))")))) + +(define expt-0-0-ρ-acc + (λ (a b) + "pow(@{a}, @{b})")) (define expt-0-0 (prim2 expt - (λ (a b z) - (values (* z (* b (expt a (- b 1)))) - (* z (* (expt a b) (log a))))))) + expt-0-0-ρ-acc + (λ (a b z) + (values (* z (* b (expt a (- b 1)))) + (* z (* (expt a b) (log a))))) + (λ (a b z) + (values "(@{z} * (@{b} * pow(@{a}, (@{b} - 1))))" + "(@{z} * (pow(@{a}, @{b}) * log(@{a})))")))) + +(define exp-0-ρ-acc + (λ (a) + "exp(@{a})")) (define exp-0 (prim1 exp - (λ (a z) - (* z (exp a))))) + exp-0-ρ-acc + (λ (a z) + (* z (exp a))) + (λ (a z) + "(@{z} * exp(@{a}))"))) + +(define log-0-ρ-acc + (λ (a) + "log(@{a})")) (define log-0 (prim1 log - (λ (a z) - (* z (/ 1 a))))) + log-0-ρ-acc + (λ (a z) + (* z (/ 1 a))) + (λ (a z) + "(@{z} * (1 / @{a}))"))) + +(define sqrt-0-ρ-acc + (λ (a) + "sqrt(@{a})")) (define sqrt-0 (prim1 sqrt - (λ (x z) - (/ z (* 2 (sqrt x)))))) + sqrt-0-ρ-acc + (λ (x z) + (/ z (* 2 (sqrt x)))) + (λ (x z) + "(@{z} / (2 * 
sqrt(@{x})))"))) (define abs-0-ρ (λ (x) @@ -51,14 +110,22 @@ ((< x 0) (* -1 x)) (else x)))) +(define abs-0-ρ-acc + (λ (x) + "fabs(@{x})")) + (define abs-0-∇ (λ (x z) (cond ((< x 0) (- z)) (else z)))) +(define abs-0-∇-acc + (λ (x z) + "sign(@{x}) * @{z}")) + (define abs-0 - (prim1 abs-0-ρ abs-0-∇)) + (prim1 abs-0-ρ abs-0-ρ-acc abs-0-∇ abs-0-∇-acc)) (define rectify-0-ρ (λ (s) @@ -66,17 +133,25 @@ ((< s 0.0) 0.0) (else s)))) +(define rectify-0-ρ-acc + (λ (s) + "fmax(0.0f, @{s})")) + (define rectify-0-∇ (λ (s z) (cond ((< s 0.0) 0.0) (else z)))) +(define rectify-0-∇-acc + (λ (s z) + "step(0, @{s}) * @{z}")) + (define rectify-shape (λ (s) s)) (define rectify-0 - (prim1 rectify-0-ρ rectify-0-∇ rectify-shape)) + (prim1 rectify-0-ρ rectify-0-ρ-acc rectify-0-∇ rectify-0-∇-acc rectify-shape)) ;;------------------------------------ ;; differentiable extended functions. @@ -102,32 +177,29 @@ ;; non-differentiable extended functions. ;;------------------------------------ -(define *-ρ (ext2-ρ * 0 0)) -(define +-ρ (ext2-ρ + 0 0)) -(define --ρ (ext2-ρ - 0 0)) -(define /-ρ (ext2-ρ / 0 0)) -(define expt-ρ (ext2-ρ expt 0 0)) - -(define exp-ρ (ext1-ρ exp 0)) -(define log-ρ (ext1-ρ log 0)) -(define abs-ρ (ext1-ρ abs-0-ρ 0)) -(define rectify-ρ (ext1-ρ rectify-0-ρ 0)) +(define *-ρ (ext2-ρ * *-0-0-ρ-acc 0 0)) +(define +-ρ (ext2-ρ + +-0-0-ρ-acc 0 0)) +(define --ρ (ext2-ρ - --0-0-ρ-acc 0 0)) +(define /-ρ (ext2-ρ / /-0-0-ρ-acc 0 0)) +(define expt-ρ (ext2-ρ expt expt-0-0-ρ-acc 0 0)) -(define sqrt-ρ - (λ (a) - (expt-ρ a 1/2))) +(define exp-ρ (ext1-ρ exp exp-0-ρ-acc 0)) +(define log-ρ (ext1-ρ log log-0-ρ-acc 0)) +(define abs-ρ (ext1-ρ abs-0-ρ abs-0-ρ-acc 0)) +(define rectify-ρ (ext1-ρ rectify-0-ρ rectify-0-ρ-acc 0)) +(define sqrt-ρ (ext1-ρ sqrt sqrt-0-ρ-acc 0)) (define sqr-ρ (λ (x) (*-ρ x x))) (define zeroes-ρ - (ext1-ρ (λ (_) 0.0) 0)) + (ext1-ρ (λ (_) 0.0) (λ (_) "0.0") 0)) (include "test/test-A-scalar-ops.rkt") (provide +-0-0 --0-0 *-0-0 /-0-0 expt-0-0 - exp-0 log-0 sqrt-0 abs-0 rectify-0 + 
exp-0 log-0 abs-0 rectify-0 sqrt-0 d+ d- d* d/ d-expt d-exp d-log d-abs diff --git a/lazy/ext-ops/B-comparators.rkt b/lazy/ext-ops/B-comparators.rkt index 7fcb184..3db8e0f 100644 --- a/lazy/ext-ops/B-comparators.rkt +++ b/lazy/ext-ops/B-comparators.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require "../autodiff.rkt") ;;---------------------------- @@ -24,7 +25,7 @@ (comparator >)) (define >=-0-0 - (comparator >=)) + (comparator >=)) ;;---------------------------- ;; Tensorized comparators @@ -37,6 +38,11 @@ ((f (ρ da) (ρ db)) 1.0) (else 0.0))))) +(define comparator-ρ-acc + (λ (f) + (λ (a b) + "@{a} @{f} @{b}"))) + (define comparator-∇ (λ (f) (λ (da db z) @@ -44,40 +50,48 @@ ((f (ρ da) (ρ db)) (values z z)) (else (values 0.0 0.0)))))) +(define comparator-∇-acc + (λ (f) + (λ (a b z) + (let ((bool "@{a} @{f} @{b}")) + (values "@{bool}*@{z}" "@{bool}*@{z}"))))) + (define comparator-shape (λ (f) (λ (sa sb) sa))) (define comparator-prim - (λ (f) - (prim2 (comparator-ρ f) (comparator-∇ f) (comparator-shape f)))) + (λ (f f-acc) + (prim2 (comparator-ρ f) (comparator-ρ-acc f-acc) + (comparator-∇ f) (comparator-∇-acc f-acc) + (comparator-shape f)))) (define extended-comparator - (λ (f) - (ext2 (comparator-prim f) 0 0))) + (λ (f f-acc) + (ext2 (comparator-prim f f-acc) 0 0))) (define =-1 - (extended-comparator =)) + (extended-comparator = "==")) (define <-1 - (extended-comparator <)) + (extended-comparator < "<")) (define >-1 - (extended-comparator >)) + (extended-comparator > ">")) (define <=-1 - (extended-comparator <=)) + (extended-comparator <= "<=")) (define >=-1 - (extended-comparator >=)) + (extended-comparator >= ">=")) (define != (λ (a b) (not (= a b)))) (define !=-1 - (extended-comparator !=)) + (extended-comparator != "!=")) (include "test/test-B-comparators.rkt") diff --git a/lazy/ext-ops/C-star-2-1.rkt b/lazy/ext-ops/C-star-2-1.rkt index 629b1ba..fab49f2 100644 --- a/lazy/ext-ops/C-star-2-1.rkt +++ b/lazy/ext-ops/C-star-2-1.rkt @@ -1,5 +1,7
@@ #lang racket +(require string-interpolation) +(require "../../accelerated-tensors/ext-impl.rkt") (require (only-in "../tensors.rkt" ext2-ρ)) (require "../autodiff.rkt") @@ -8,35 +10,69 @@ v1 i1 stride1 v-out i-out stride-out) (for ([i (in-range 0 stride-out)]) - (vector-set! v-out (+ i-out i) - (* (vector-ref v0 (+ i0 i)) - (vector-ref v1 (+ i1 (modulo i stride1)))))))) + (vset! v-out (+ i-out i) + (* (vref v0 (+ i0 i)) + (vref v1 (+ i1 (modulo i stride1)))))))) + +(define *-2-1-base-ρ-acc + (λ (v0 i0 stride0 + v1 i1 stride1 + v-out i-out stride-out) + #< v max) (values v (+ (- i i0) 0.0))) (else (values max max-i)))))))) +(define argmax-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< max) { + max = v; + max_i = i - @{i0} + 0.0; + } + } + @{v-out}[@{i-out}] = max_i; +EOF + )) + (define argmax-1-∇ (λ (g0 v0 i0 stride0 vz iz stride-z) - (let ((z (vector-ref vz iz))) + (let ((z (vref vz iz))) (for ([i (in-range i0 (+ i0 stride0))]) - (vector-set! g0 i 0.0))))) + (vset! g0 i 0.0))))) + +(define argmax-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< v max) v) (else max))))))) +(define max-1-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< v max) (values v (- i i0))) (else (values max max-i)))))))) +(define max-1-∇-acc + (λ (g0 v0 i0 stride0 + vz iz stride-z) + #< max) { + max = v; + max_i = i - @{i0}; + } + } + for(int i=@{i0}; i<@{i0}+@{stride0}; i++) { + if(i == @{i0}+max_i) { + @{g0}[i] += z; + } else { + @{g0}[i] += 0.0; + } + } +EOF + )) + (define max-shape (λ (st) (cdr st))) (define max-1 - (prim1 max-1-ρ max-1-∇ max-shape #t)) + (prim1 max-1-ρ max-1-ρ-acc max-1-∇ max-1-∇-acc max-shape #t)) (define d-max (ext1 max-1 1)) (define max-ρ - (ext1-ρ max-1-ρ 1 max-shape #t)) + (ext1-ρ max-1-ρ max-1-ρ-acc 1 max-shape #t)) (include "test/test-F-max.rkt") diff --git a/lazy/ext-ops/G-correlate.rkt b/lazy/ext-ops/G-correlate.rkt index cfe156b..14c0cb6 100644 --- a/lazy/ext-ops/G-correlate.rkt +++ b/lazy/ext-ops/G-correlate.rkt @@ -1,5 +1,7 @@ #lang 
racket +(require string-interpolation) +(require "../../accelerated-tensors/ext-impl.rkt") (require (only-in "../tensors.rkt" ext2-ρ len)) (require "../autodiff.rkt") @@ -17,17 +19,39 @@ (let* ((i1-min (- i1 (modulo i1 nd))) (i1-max (+ i1-min nd))) (for ((i (in-range 0 b))) - (vector-set! v-out (+ i-out i) + (vset! v-out (+ i-out i) (for/fold ([sum 0.0]) ([j (in-range 0 md)]) (let ((ai (+ i0 (* i md) j)) (bi (- (+ i1 j) qd))) (cond ((and (>= bi i1-min) (< bi i1-max)) - (let ((a (vector-ref v0 ai)) - (b (vector-ref v1 bi))) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) (+ sum (* a b)))) (else sum)))))))))) +(define correlate-3-1-ρ-acc + (λ (nd md qd) + (λ (v0 i0 _ + v1 i1 d + v-out i-out b) + #<= i1_min && bi < i1_max) { + sum += @{v0}[ai] * @{v1}[bi]; + } + } + @{v-out}[@{i-out}+i] = sum; + } +EOF + ))) + (define correlate-3-1-∇ (λ (nd md qd) (λ (g0 g1 @@ -37,17 +61,55 @@ (let* ((i1-min (- i1 (modulo i1 nd))) (i1-max (+ i1-min nd))) (for ((i (in-range 0 b))) - (let ((z (vector-ref vz (+ iz i)))) + (let ((z (vref vz (+ iz i)))) (for ([j (in-range 0 md)]) (let ((ai (+ i0 (* i md) j)) (bi (- (+ i1 j) qd))) (when (and (>= bi i1-min) (< bi i1-max)) - (let ((a (vector-ref v0 ai)) - (b (vector-ref v1 bi))) - (vector-set! g0 ai - (+ (vector-ref g0 ai) (* z b))) - (vector-set! g1 bi - (+ (vector-ref g1 bi) (* z a))))))))))))) + (let ((a (vref v0 ai)) + (b (vref v1 bi))) + (vset! g0 ai + (+ (vref g0 ai) (* z b))) + (vset! 
g1 bi + (+ (vref g1 bi) (* z a))))))))))))) + +(define correlate-3-1-∇-acc + (λ (nd md qd) + (λ (g + v0 i0 bmd + v1 i1 d + vz iz b) + (values + #<= i1_min && bi < i1_max) { + @{g}[ai] += z * @{v1}[bi]; + } + } + } +EOF + + #<= i1_min && bi < i1_max) { + @{g}[bi] += z * @{v0}[ai]; + } + } + } +EOF + )))) (define correlate-shape (λ (bmd nd) @@ -57,9 +119,10 @@ (λ (nd md qd) (prim2 (correlate-3-1-ρ nd md qd) + (correlate-3-1-ρ-acc nd md qd) (correlate-3-1-∇ nd md qd) - correlate-shape - #t))) + (correlate-3-1-∇-acc nd md qd) + correlate-shape #t))) (define d-correlate (λ (bank signal) @@ -83,7 +146,7 @@ (q (/ (- m 1) 2)) ;; This is the padding. (qd (* q d)) (md (* m d))) - ((ext2-ρ (correlate-3-1-ρ nd md qd) 3 1 correlate-shape #t) + ((ext2-ρ (correlate-3-1-ρ nd md qd) (correlate-3-1-ρ-acc nd md qd) 3 1 correlate-shape #t) bank signal)))) (define last diff --git a/lazy/ext-ops/I-flatten.rkt b/lazy/ext-ops/I-flatten.rkt index bf24773..0ef02f6 100644 --- a/lazy/ext-ops/I-flatten.rkt +++ b/lazy/ext-ops/I-flatten.rkt @@ -1,5 +1,6 @@ #lang racket +(require string-interpolation) (require (only-in "../tensors.rkt" ext1-ρ tref reshape shape ref)) (require (only-in "../autodiff.rkt" prim1 ext1)) @@ -7,10 +8,30 @@ (λ (t) (reshape (flatten-shape (shape t)) t))) +(define flatten-2-ρ-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #<preallocated-1-ρ f base-shape shape-fn-out))) - (tpmake-ext1-ρ flat-f prim-sign m shape-fn tp out-shape))])))))) + (let ((flat-f (functional->preallocated-1-ρ f base-shape shape-fn-out)) + (flat-f-acc (functional->preallocated-1-ρ-acc f-acc base-shape shape-fn-out))) + (tpmake-ext1-ρ flat-f flat-f-acc prim-sign m shape-fn tp out-shape))])))))) ;; See comment for tp-ext1-ρ (define tp-ext2-ρ (let ((id -1)) - (λ (f m n + (λ (f f-acc m n [shape-fn scalar-shape] [expects-prealloc? #f] [prim-sign (begin @@ -216,19 +217,21 @@ instructions refering to the same gensym variable [(and (tpromise? tp-t) (tpromise? tp-u) (null? (tpromise-shape tp-t)) (null? 
(tpromise-shape tp-u))) - (tpmake-ext2-ρ-scalar f prim-sign tp-t tp-u sf-out)] + (tpmake-ext2-ρ-scalar f f-acc prim-sign tp-t tp-u sf-out)] [expects-prealloc? (tpmake-ext2-ρ tp-t tp-u - f prim-sign m n shape-fn + f f-acc prim-sign m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out)))] [else (let ((flat-f (functional->preallocated-2-ρ - f sf0 sf1 sf-out))) + f sf0 sf1 sf-out)) + (flat-f-acc (functional->preallocated-2-ρ-acc + f-acc sf0 sf1 sf-out))) (tpmake-ext2-ρ tp-t tp-u - flat-f prim-sign m n shape-fn + flat-f flat-f-acc prim-sign m n shape-fn (ext2-shapes s0 s1 m n sf-out (λ (s-out . _) s-out))))])))))) @@ -238,7 +241,7 @@ instructions refering to the same gensym variable ;; See comment for tp-ext1-ρ (define tp-ext1-∇ (let ((id -1)) - (λ (f m + (λ (f f-acc m [shape-fn scalar-shape] [expects-prealloc? #f] [prim-sign (begin @@ -249,58 +252,63 @@ instructions refering to the same gensym variable (cond ((number? tp) (f tp zp)) (expects-prealloc? - (tpmake-ext1-∇ tp zp f prim-sign m shape-fn (tp-shape tp))) + (tpmake-ext1-∇ tp zp f f-acc prim-sign m shape-fn (tp-shape tp))) (else (let* ((in-shape (tpromise-shape tp)) (base-shape (min-shape m in-shape)) (out-shape (shape-fn base-shape)) - (flat-f (functional->preallocated-1-∇ f base-shape out-shape))) - (tpmake-ext1-∇ tp zp flat-f prim-sign m shape-fn (tp-shape tp))))))))) + (flat-f (functional->preallocated-1-∇ f base-shape out-shape)) + (flat-f-acc (functional->preallocated-1-∇-acc f-acc base-shape out-shape))) + (tpmake-ext1-∇ tp zp flat-f flat-f-acc prim-sign m shape-fn (tp-shape tp))))))))) ;; See comment for tp-ext1-ρ (define tp-ext2-∇ (let ((id -1)) - (λ (f m n + (λ (f f-acc m n [shape-fn scalar-shape] [expects-prealloc? #f] [prim-sign (begin (set! 
id (add1 id)) (string-append "ne2" (~r id #:base 16)))]) (let ((tp-f - (λ (f tp-t tp-u tp-z) - (tp-d-ext2^ f prim-sign m n shape-fn + (λ (f f-acc tp-t tp-u tp-z) + (tp-d-ext2^ f f-acc prim-sign m n shape-fn tp-t tp-u tp-z)))) (λ (tp-t tp-u tp-z) (cond (expects-prealloc? - (tp-f f tp-t tp-u tp-z)) + (tp-f f f-acc tp-t tp-u tp-z)) [else (let* ((t-shape (min-shape m (tp-shape tp-t))) (u-shape (min-shape n (tp-shape tp-u))) (out-shape (shape-fn t-shape u-shape)) (flat-f (functional->preallocated-2-∇ - f t-shape u-shape out-shape))) - (tp-f flat-f tp-t tp-u tp-z))])))))) + f t-shape u-shape out-shape)) + (flat-f-acc (functional->preallocated-2-∇-acc + f-acc t-shape u-shape out-shape))) + (tp-f flat-f flat-f-acc tp-t tp-u tp-z))])))))) (define tp-d-ext2^ - (λ (fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) + (λ (fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z) (let* ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) (values - (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn + (tpmake-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 0 (tp-shape tp-t0)) - (tpmake-ext2-∇ fᵈ sign r0 r1 shape-fn + (tpmake-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 1 (tp-shape tp-t1)))))) (define tp-rank (λ (tp) - (flat:len (tp-shape tp)))) + (acc:len (tp-shape tp)))) (define tp-reshape (λ (s tp) (cond - ((= (flat:size-of s) (flat:size-of (tpromise-shape tp))) + ((and (tpromise? tp) (= (acc:size-of s) (acc:size-of (tpromise-shape tp)))) (tpmake-reshape tp s)) - (else (error 'shape-error "Cannot reshape ~a to ~a~%" (tpromise-shape tp) s))))) + [(and (acc:flat? tp) (= (acc:size-of s) (acc:size-of (acc:shape tp)))) + (acc:reshape s tp)] + (else (error 'shape-error "Cannot reshape ~a to ~a~%" tp s))))) (define tensor? 
(lambda (tp) @@ -311,9 +319,9 @@ instructions refering to the same gensym variable (provide start-vector-manager vector-manager-report) (provide (rename-out - (flat:len len) - (flat:ref ref) - (flat:refr refr))) + (acc:len len) + (acc:ref ref) + (acc:refr refr))) (provide tensor tpromise? (rename-out @@ -335,4 +343,4 @@ instructions refering to the same gensym variable (tp-rank rank) (tp-shape shape) (tp-reshape reshape) - (flat:size-of size-of))) + (acc:size-of size-of))) diff --git a/lazy/tensors/1-reflect.rkt b/lazy/tensors/1-reflect.rkt index df6575f..bc2a0c9 100644 --- a/lazy/tensors/1-reflect.rkt +++ b/lazy/tensors/1-reflect.rkt @@ -1,6 +1,6 @@ #lang racket -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (require "c0-ast.rkt") (require (only-in "c3-compiler.rkt" compiler-cache @@ -42,11 +42,10 @@ (cond [(and (tpromise? tp) (null? (tpromise-shape tp))) (tp-scalarize (↓ tp))] - [(and (flat:flat? tp) (null? (flat:flat-shape tp))) - (vector-ref (flat:flat-store tp) 0)] + [(and (acc:flat? tp) (null? (acc:flat-shape tp))) + (vector-ref (acc:flat-store tp) 0)] [else tp]))) -;; TODO: these force functions will be moved to the openCL runtime (define force*1 (λ (t f) (f (↓ t)))) diff --git a/lazy/tensors/A-equality.rkt b/lazy/tensors/A-equality.rkt index 7c420ca..8cfa8ba 100644 --- a/lazy/tensors/A-equality.rkt +++ b/lazy/tensors/A-equality.rkt @@ -1,11 +1,11 @@ #lang racket (require "1-reflect.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define tp-tensor-equal? (λ (tp-actual tp-expected) - (flat:tensor-equal? (↓ tp-actual) (↓ tp-expected)))) + (acc:tensor-equal? (↓ tp-actual) (↓ tp-expected)))) (require rackunit) (define-binary-check (tp-check-tensor-equal? tp-tensor-equal? 
actual expected)) @@ -13,6 +13,6 @@ (include "test/test-A-equality.rkt") (provide (rename-out - (flat:tolerance tolerance) + (acc:tolerance tolerance) (tp-tensor-equal? tensor-equal?) (tp-check-tensor-equal? check-tensor-equal?))) diff --git a/lazy/tensors/B-test-programs.rkt b/lazy/tensors/B-test-programs.rkt index 450de2f..f301c73 100644 --- a/lazy/tensors/B-test-programs.rkt +++ b/lazy/tensors/B-test-programs.rkt @@ -1,7 +1,8 @@ #lang racket +(require string-interpolation) (require "0-lazy.rkt") -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define make-tref-test-program (λ (t) @@ -22,29 +23,29 @@ 'tensor-r1-0 (test-program-data (λ () (tensor 1 2 3)) - (eval-res-1 (flat:tensor 1 2 3))) + (eval-res-1 (acc:tensor 1 2 3))) 'tensor-r1-1 (test-program-data (λ () (tensor 1 2 3 4 5)) - (eval-res-1 (flat:tensor 1 2 3 4 5))) + (eval-res-1 (acc:tensor 1 2 3 4 5))) 'tensor-r1-2 (test-program-data (λ () (tensor 3.0 4.0 5.0)) - (eval-res-1 (flat:tensor 3.0 4.0 5.0))) + (eval-res-1 (acc:tensor 3.0 4.0 5.0))) 'tensor-r2-0 (test-program-data (λ () (tensor (tensor 1 2 3) (tensor 4 5 6))) - (eval-res-1 (flat:tensor (flat:tensor 1 2 3) (flat:tensor 4 5 6)))) + (eval-res-1 (acc:tensor (acc:tensor 1 2 3) (acc:tensor 4 5 6)))) 'tensor-r2-1 (test-program-data (λ () (reshape '(2 3) (tensor 3.0 4.0 5.0 7.0 8.0 9.0))) (eval-res-1 - (flat:reshape '(2 3) (flat:tensor 3.0 4.0 5.0 7.0 8.0 9.0)))) + (acc:reshape '(2 3) (acc:tensor 3.0 4.0 5.0 7.0 8.0 9.0)))) 'build-tensor-r1-0 (test-program-data (λ () (build-tensor '(6) (λ (i) (* 3.0 (car i))))) - (eval-res-1 (flat:build-tensor '(6) + (eval-res-1 (acc:build-tensor '(6) (λ (i) (* 3.0 (car i)))))) 'build-tensor-r2-0 (test-program-data (λ () @@ -52,7 +53,7 @@ (λ (i) (match-define `(,x ,y) i) (* 2.0 (+ (* x 6) y))))) - (eval-res-1 (flat:build-tensor '(5 6) + 
(eval-res-1 (acc:build-tensor '(5 6) (λ (i) (match-define `(,x ,y) i) (* 2.0 (+ (* x 6) y)))))) @@ -62,7 +63,7 @@ (λ (i) (match-define `(,x ,y) i) (* 3.0 (+ (* x 6) y))))) - (eval-res-1 (flat:build-tensor '(3 6) + (eval-res-1 (acc:build-tensor '(3 6) (λ (i) (match-define `(,x ,y) i) (* 3.0 (+ (* x 6) y)))))) @@ -72,7 +73,7 @@ (λ (i) (match-define `(,x ,y ,z) i) (* 2 (+ (* x 12) (* y 4) (* 1 z)))))) - (eval-res-1 (flat:build-tensor + (eval-res-1 (acc:build-tensor '(2 3 4) (λ (i) (match-define `(,x ,y ,z) i) @@ -83,7 +84,7 @@ (λ (i) (match-define `(,x ,y ,z) i) (* 2.0 (+ (* x 30) (* y 6) (* 1 z)))))) - (eval-res-1 (flat:build-tensor + (eval-res-1 (acc:build-tensor '(3 5 6) (λ (i) (match-define `(,x ,y ,z) i) @@ -98,7 +99,7 @@ (let ((tp (trefs (get-test-program 'tensor-r1-0) '(0 2)))) (+-ρ tp tp))) - (eval-res-1 (flat:tensor 2 6))) + (eval-res-1 (acc:tensor 2 6))) 'built-tensor (test-program-data (λ () (let ((test-build-shape '(4 3))) @@ -109,16 +110,16 @@ (+ (* (sub1 (car test-build-shape)) row) column)))))) - (eval-res-1 (flat:tensor (flat:tensor 0 1 2) - (flat:tensor 3 4 5) - (flat:tensor 6 7 8) - (flat:tensor 9 10 11)))) + (eval-res-1 (acc:tensor (acc:tensor 0 1 2) + (acc:tensor 3 4 5) + (acc:tensor 6 7 8) + (acc:tensor 9 10 11)))) 'multi-built-tensor (test-program-data (λ () (+-ρ (get-test-program 'build-tensor-r2-0) (tref (get-test-program 'build-tensor-r3-1) 0))) - (eval-res-1 ((flat:ext2-ρ * 0 0) 2 - (flat:build-tensor + (eval-res-1 ((acc:ext2-ρ * (λ (a b) "@{a} * @{b}") 0 0) 2 + (acc:build-tensor '(5 6) (λ (i) (match-define `(,x ,y) i) @@ -134,45 +135,45 @@ 'tcomp-list->tensor (test-program-data (λ () (make-list->tensor-test-program '(5 6 7 8))) - (eval-res-1 (flat:tensor 5 6 7 8))) + (eval-res-1 (acc:tensor 5 6 7 8))) 'tcomp-nested-list->tensor (test-program-data (λ () (list->tensor `(,(get-test-program 'tensor-r1-0) ,(get-test-program 'tensor-r1-0) ,(get-test-program 'tensor-r1-0)))) - (eval-res-1 (flat:tensor - (flat:tensor 1 2 3) - (flat:tensor 1 2 3) - 
(flat:tensor 1 2 3)))) + (eval-res-1 (acc:tensor + (acc:tensor 1 2 3) + (acc:tensor 1 2 3) + (acc:tensor 1 2 3)))) 'tcomp-trefs (test-program-data (λ () (trefs (get-test-program 'built-tensor) '(0 2))) - (eval-res-1 (flat:tensor (flat:tensor 0 1 2) - (flat:tensor 6 7 8)))) + (eval-res-1 (acc:tensor (acc:tensor 0 1 2) + (acc:tensor 6 7 8)))) 'tcomp-reshape (test-program-data (λ () (reshape '(3 2 1) (trefs (get-test-program 'built-tensor) '(1 3)))) - (eval-res-1 (flat:tensor (flat:tensor (flat:tensor 3) - (flat:tensor 4)) - (flat:tensor (flat:tensor 5) - (flat:tensor 9)) - (flat:tensor (flat:tensor 10) - (flat:tensor 11))))) + (eval-res-1 (acc:tensor (acc:tensor (acc:tensor 3) + (acc:tensor 4)) + (acc:tensor (acc:tensor 5) + (acc:tensor 9)) + (acc:tensor (acc:tensor 10) + (acc:tensor 11))))) 'sum (test-program-data (λ () (sum (get-test-program 'tensor-r2-0))) - (eval-res-1 (flat:tensor 6.0 15.0))) + (eval-res-1 (acc:tensor 6.0 15.0))) 'sum-nested (test-program-data (λ () (tensor 4.0 (sum (tensor 1 2 3)) 5.0)) - (eval-res-1 (flat:tensor 4.0 6.0 5.0))) + (eval-res-1 (acc:tensor 4.0 6.0 5.0))) 'id (test-program-data (λ () (id-ρ (get-test-program 'tensor-r2-0))) - (eval-res-1 (flat:tensor (flat:tensor 1 2 3) - (flat:tensor 4 5 6)))) + (eval-res-1 (acc:tensor (acc:tensor 1 2 3) + (acc:tensor 4 5 6)))) 'id-scalar (test-program-data (λ () (id-ρ (sum (tensor 4 5 6)))) @@ -185,9 +186,9 @@ (λ () (*-ρ (get-test-program 'build-tensor-r3-0) (get-test-program 'build-tensor-r3-0))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(2 3 4) - (flat:tensor + (acc:tensor 0 4 16 36 64 100 144 196 256 324 400 484 @@ -198,9 +199,9 @@ (λ () (*-2-1 (get-test-program 'build-tensor-r2-0) (get-test-program 'build-tensor-r1-0))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(5 6) - (flat:tensor + (acc:tensor 0 6.0 24.0 54.0 96.0 150.0 0 42.0 96.0 162.0 240.0 330.0 0 78.0 168.0 270.0 384.0 510.0 @@ -210,9 +211,9 @@ (λ () (*-2-1 (get-test-program 'build-tensor-r3-1) 
(get-test-program 'build-tensor-r2-1))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(3 5 6) - (flat:tensor + (acc:tensor 0 6.0 24.0 54.0 96.0 150.0 0 42.0 96.0 162.0 240.0 330.0 0 78.0 168.0 270.0 384.0 510.0 @@ -238,14 +239,14 @@ 'tcomp-dsqr-r1 (test-program-data (λ () (d-sqr r1-td (one-like r1-td))) - (eval-res-1 (flat:tensor 6.0 8.0 10.0))) + (eval-res-1 (acc:tensor 6.0 8.0 10.0))) 'gsqr (test-program-data (λ () (let ([r2-td (get-test-program 'tensor-r2-1)]) (d-sqr r2-td (one-like r2-td)))) - (eval-res-1 (flat:reshape + (eval-res-1 (acc:reshape '(2 3) - (flat:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) + (acc:tensor 6.0 8.0 10.0 14.0 16.0 18.0)))) 'g+ (test-program-data (λ () (d+ 2.0 3.0 1.0)) @@ -253,42 +254,42 @@ 'g-twice (test-program-data (λ () (d+ r1-td r1-td (one-like r1-td))) - (eval-res-2 (flat:tensor 1.0 1.0 1.0) - (flat:tensor 1.0 1.0 1.0))) + (eval-res-2 (acc:tensor 1.0 1.0 1.0) + (acc:tensor 1.0 1.0 1.0))) 'g+-r1-r2 (test-program-data (λ () (let ((r2-td (get-test-program 'tensor-r2-1))) (d+ r1-td r2-td (one-like r2-td)))) - (eval-res-2 (flat:tensor 2.0 2.0 2.0) - (flat:reshape + (eval-res-2 (acc:tensor 2.0 2.0 2.0) + (acc:reshape '(2 3) - (flat:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) + (acc:tensor 1.0 1.0 1.0 1.0 1.0 1.0)))) 'g* (test-program-data (λ () (*∇ (tensor 2.0 3.0 4.0) (tensor 1.0 2.0 3.0) (tensor 1.0 1.0 1.0))) - (eval-res-2 (flat:tensor 1.0 2.0 3.0) - (flat:tensor 2.0 3.0 4.0))) + (eval-res-2 (acc:tensor 1.0 2.0 3.0) + (acc:tensor 2.0 3.0 4.0))) 'gsum-r1 (test-program-data (λ () (sum-∇ (tensor 2.0 3.0 4.0) 1.0)) - (eval-res-1 (flat:tensor 1.0 1.0 1.0))) + (eval-res-1 (acc:tensor 1.0 1.0 1.0))) 'gsum-r2 (test-program-data (λ () (sum-∇ (tensor (tensor 2.0 3.0 4.0) (tensor 2.0 3.0 4.0)) (tensor 2.0 1.0))) - (eval-res-1 (flat:tensor (flat:tensor 2.0 2.0 2.0) - (flat:tensor 1.0 1.0 1.0)))) + (eval-res-1 (acc:tensor (acc:tensor 2.0 2.0 2.0) + (acc:tensor 1.0 1.0 1.0)))) 'gs2-r1 (test-program-data (λ () (s2-∇ (tensor 2.0 3.0 4.0) (tensor 1.0 
2.0 3.0) (tensor 1.0 1.0))) - (eval-res-2 (flat:tensor 1.0 1.0 1.0) - (flat:tensor 1.0 1.0 1.0))) + (eval-res-2 (acc:tensor 1.0 1.0 1.0) + (acc:tensor 1.0 1.0 1.0))) 'gs2-r3 (test-program-data (λ () (s2-∇ (tensor (tensor (tensor 1.0 2.0 6.0) @@ -309,20 +310,20 @@ (tensor 1.0 1.0)) (tensor (tensor 1.0 1.0) (tensor 1.0 1.0))))) - (eval-res-2 (flat:reshape '(3 2 3) - (flat:list->tensor (make-list 18 1.0))) - (flat:reshape '(3 2 3) - (flat:list->tensor (make-list 18 1.0))))) + (eval-res-2 (acc:reshape '(3 2 3) + (acc:list->tensor (make-list 18 1.0))) + (acc:reshape '(3 2 3) + (acc:list->tensor (make-list 18 1.0))))) 'env-flat-scalar (test-program-data (λ () ((λ (theta) (*-ρ (list-ref theta 0) (list-ref theta 1))) (list (tensor 1.0) 3.0))) - (eval-res-1 (flat:tensor 3.0))) + (eval-res-1 (acc:tensor 3.0))) 'common-subexpression (test-program-data (λ () (let ((t (tref (tensor 1 2 3) 0))) (tensor t t))) - (eval-res-1 (flat:tensor 1.0 1.0))) + (eval-res-1 (acc:tensor 1.0 1.0))) 'nested-common-subexpression (test-program-data (λ () (let ((t1 (tref (tensor (tensor 1 2 3) @@ -330,7 +331,7 @@ 0))) (let ((t0 (tref t1 0))) (tensor t0 t0)))) - (eval-res-1 (flat:tensor 1.0 1.0))) + (eval-res-1 (acc:tensor 1.0 1.0))) )) (define get-test-program @@ -345,13 +346,24 @@ (vset! 
out-v iₒ (for/fold ([sum 0.0]) ([i (in-range iᵢ (+ iᵢ sᵢ))]) (+ sum (vref in-v i)))))) - -(define sum (ext1-ρ sum-f 1 (λ (s) '()) #t)) +(define sum-f-acc + (λ (v0 i0 stride0 + v-out i-out stride-out) + #< (Vector Number) Natural (Listof Natural) (Vector Number) Natural (Listof Natural) (Vector Number) Natural (Listof Natural)))) -(struct tcomp-ext1-ρ-scalar tcomp (f sign tp) #:transparent) -(struct tcomp-ext1-ρ tcomp (f sign m shape-fn tp) #:transparent) -(struct tcomp-ext2-ρ-scalar tcomp (f sign tp-t tp-u) #:transparent) -(struct tcomp-ext2-ρ tcomp (tp-t tp-u f sign m n shape-fn) #:transparent) -(struct tcomp-ext1-∇ tcomp (tp zp f sign m shape-fn) #:transparent) -(struct tcomp-ext2-∇ tcomp (fᵈ +(struct tcomp-ext1-ρ-scalar tcomp (f f-acc sign tp) #:transparent) +(struct tcomp-ext1-ρ tcomp (f f-acc sign m shape-fn tp) #:transparent) +(struct tcomp-ext2-ρ-scalar tcomp (f f-acc sign tp-t tp-u) #:transparent) +(struct tcomp-ext2-ρ tcomp (tp-t tp-u f f-acc sign m n shape-fn) #:transparent) +(struct tcomp-ext1-∇ tcomp (tp zp f f-acc sign m shape-fn) #:transparent) +(struct tcomp-ext2-∇ tcomp (fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) @@ -98,7 +98,7 @@ (define gdst-trefs (λ (tp i-lst) - (box (list (tpromise-dst tp) (flat:list->tensor i-lst))))) + (box (list (tpromise-dst tp) (acc:list->tensor i-lst))))) (define gdst-ext2-∇ (λ (tp-t0 tp-t1 tp-z) @@ -216,24 +216,24 @@ (gs-trefs tp)))) (define tpmake-ext1-ρ-scalar - (λ (f signature tp shape) - (tpromise (tcomp-ext1-ρ-scalar f signature tp) shape + (λ (f f-acc prim-sign tp shape) + (tpromise (tcomp-ext1-ρ-scalar f f-acc prim-sign tp) shape (box (list (tpromise-dst tp))) - (gs-ext1-ρ-scalar signature tp)))) + (gs-ext1-ρ-scalar prim-sign tp)))) (define tpmake-ext1-ρ - (λ (f signature m shape-fn tp shape) - (tpromise (tcomp-ext1-ρ f signature m shape-fn tp) + (λ (f f-acc prim-sign m shape-fn tp shape) + (tpromise (tcomp-ext1-ρ f f-acc prim-sign m shape-fn tp) shape (box (list (tpromise-dst tp))) - 
(gs-ext1-ρ signature m tp)))) + (gs-ext1-ρ prim-sign m tp)))) (define tpmake-ext2-ρ-scalar - (λ (f signature tp-t tp-u shape) - (tpromise (tcomp-ext2-ρ-scalar f signature tp-t tp-u) + (λ (f f-acc prim-sign tp-t tp-u shape) + (tpromise (tcomp-ext2-ρ-scalar f f-acc prim-sign tp-t tp-u) shape (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) - (gs-ext2-ρ-scalar signature tp-t tp-u)))) + (gs-ext2-ρ-scalar prim-sign tp-t tp-u)))) (define ensure-tpromise (λ (v) @@ -243,14 +243,14 @@ (else v)))) (define tpmake-ext2-ρ - (λ (tp-t tp-u f signature m n shape-fn shape) + (λ (tp-t tp-u f f-acc prim-sign m n shape-fn shape) (let ((tp-t (ensure-tpromise tp-t)) (tp-u (ensure-tpromise tp-u))) (tpromise - (tcomp-ext2-ρ tp-t tp-u f signature m n shape-fn) + (tcomp-ext2-ρ tp-t tp-u f f-acc prim-sign m n shape-fn) shape (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) - (gs-ext2-ρ signature m n tp-t tp-u))))) + (gs-ext2-ρ prim-sign m n tp-t tp-u))))) ;; we invoke ensure-tpromise on just zp because it's the result of calling ;; force*1 which forces zp to be a non-tpromise value. We can ensure tp to @@ -258,25 +258,25 @@ ;; before passing it to this function, nor do we need scalar tp to be wrapped in ;; a tpromise. 
(define tpmake-ext1-∇ - (λ (tp zp f signature m shape-fn shape) + (λ (tp zp f f-acc prim-sign m shape-fn shape) (let ((zp (ensure-tpromise zp))) (tpromise - (tcomp-ext1-∇ tp zp f signature m shape-fn) + (tcomp-ext1-∇ tp zp f f-acc prim-sign m shape-fn) shape (box (list (tpromise-dst tp) (tpromise-dst zp))) - (gs-ext1-∇ signature m tp zp))))) + (gs-ext1-∇ prim-sign m tp zp))))) (define tpmake-ext2-∇ - (λ (fᵈ signature r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) + (λ (fᵈ fᵈ-acc prim-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i shape) (let ((tp-t0 (ensure-tpromise tp-t0)) (tp-t1 (ensure-tpromise tp-t1)) (tp-z (ensure-tpromise tp-z))) (tpromise - (tcomp-ext2-∇ fᵈ signature r0 r1 shape-fn + (tcomp-ext2-∇ fᵈ fᵈ-acc prim-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) shape (gdst-ext2-∇ tp-t0 tp-t1 tp-z) - (gs-ext2-∇ signature r0 r1 tp-t0 tp-t1 tp-z i))))) + (gs-ext2-∇ prim-sign r0 r1 tp-t0 tp-t1 tp-z i))))) (define tpmake-reshape (λ (tp shape) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 0600e04..9fad17f 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -1,26 +1,28 @@ #lang racket -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) +(require ffi/vector) +(require "../../impl-loader.rkt") +(require "../../accelerated-tensors/ext-impl.rkt") +(require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (struct ext2-∇-result (res) #:mutable #:transparent) (define ext2-∇-forcer! 
- (λ (fᵈ r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) + (λ (fᵈ fᵈ-acc fᵈ-sign r0 r1 shape-fn t0 t1 z out-idx0 out-idx1) (let* ((f0 (ensure-flat t0)) (f1 (ensure-flat t1)) (fz (ensure-flat z)) (s0 (flat-shape f0)) (sf0 (min-shape r0 s0)) - (stride0 (flat:size-of sf0)) + (stride0 (acc:size-of sf0)) (s1 (flat-shape t1)) (sf1 (min-shape r1 s1)) - (stride1 (flat:size-of sf1)) + (stride1 (acc:size-of sf1)) (sf-z (shape-fn sf0 sf1)) - (stride-z (flat:size-of sf-z)) + (stride-z (acc:size-of sf-z)) (v0 (flat-store f0)) (v1 (flat-store f1)) @@ -32,29 +34,22 @@ (ext2-shapes s0 s1 r0 r1 sf-z (λ (sz size-z q0 q1 strides) - (let ((g0 (new-vec (flat:size-of - s0) - 0.0)) - (g1 (new-vec (flat:size-of - s1) - 0.0))) - (for ([iz (in-range - 0 - size-z - stride-z)]) - (let-values (((i0 i1) - (idxs - strides - iz - off0 - off1))) - (fᵈ g0 g1 v0 i0 - stride0 - v1 i1 - stride1 - vz - (+ offz iz) - stride-z))) + (let ((g0 (new-vec (acc:size-of s0) 0.0)) + (g1 (new-vec (acc:size-of s1) 0.0))) + (cond + ((accelerate?) + (let-values (((kernel-code kernel-name) + (ext2-∇-kernel/name fᵈ-acc fᵈ-sign strides s0 s1 r0 r1 sz + (length sf-z)))) + (run-prim2-∇! kernel-code kernel-name + g0 g1 + v0 off0 (acc:size-of s0) stride0 + v1 off1 (acc:size-of s1) stride1 + vz offz size-z stride-z))) + (else + (for ([iz (in-range 0 size-z stride-z)]) + (let-values (((i0 i1) (idxs strides iz off0 off1))) + (fᵈ g0 g1 v0 i0 stride0 v1 i1 stride1 vz (+ offz iz) stride-z))))) (when out-idx0 (data-segment-set! out-idx0 (scalarize (flat s0 g0 0)))) (when out-idx1 @@ -63,7 +58,7 @@ (define rt:trefs (λ (ft b) (cond - ((= (flat:rank b) 1) (flat:trefs ft (vector->list (flat-store b)))) + ((= (acc:rank b) 1) (acc:trefs ft (map inexact->exact (f32vector->list (flat-store b))))) (else (error 'trefs-err "~a should be a tensor¹" b))))) (define data-segment @@ -81,7 +76,7 @@ (define runtime (namespace-anchor->namespace a)) -(provide runtime flat? 
flat:build-tensor flat:list->tensor - flat:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! +(provide runtime flat? acc:build-tensor acc:list->tensor + acc:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index 423029c..ebd84c4 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -2,8 +2,8 @@ (require "c0-ast.rkt") (require (only-in "c1-racket-runtime.rkt" - runtime flat? flat:build-tensor flat:list->tensor - set-ext2-∇-result-res! flat:tref rt:trefs ext2-∇-result-res + runtime flat? acc:build-tensor acc:list->tensor + set-ext2-∇-result-res! acc:tref rt:trefs ext2-∇-result-res ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref)) @@ -18,14 +18,14 @@ ((tpromise? arg) (interp-tpromise arg env)) ((number? arg) arg) (else (error 'interp-list->tensor "Unexpected: ~a" arg)))))) - (flat:list->tensor eval-list))] + (acc:list->tensor eval-list))] [(tcomp-tref tp (and i (tcomp-ds-ref _))) - (flat:tref (interp-tpromise tp env) + (acc:tref (interp-tpromise tp env) (interp-tcomp i env))] [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (rt:trefs (interp-tpromise tp env) (interp-tcomp b env))] - [(tcomp-ext2-∇ fᵈ _ r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ fᵈ-acc f-sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let ((t0-instrs (interp-tpromise tp-t0 env)) (t1-instrs (interp-tpromise tp-t1 env)) (z-instrs (interp-tpromise tp-z env))) @@ -44,28 +44,28 @@ (v (data-segment-ref index))) (cond ((eqv? v 'uncalculated) - (ext2-∇-forcer! fᵈ r0 r1 shape-fn + (ext2-∇-forcer! 
fᵈ fᵈ-acc f-sign r0 r1 shape-fn t0-instrs t1-instrs z-instrs out-idx0 out-idx1) (data-segment-ref index)) (else v))))] - [(tcomp-ext1-∇ tp zp f _ m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc f-sign m shape-fn) (scalarize - (flat-ext1-∇ f m shape-fn + (flat-ext1-∇ f f-acc m shape-fn f-sign (ensure-flat (interp-tpromise tp env)) (ensure-flat (interp-tpromise zp env))))] - [(tcomp-ext2-ρ-scalar f _ tp-t tp-u) + [(tcomp-ext2-ρ-scalar f f-acc _ tp-t tp-u) (f (interp-tpromise tp-t env) (interp-tpromise tp-u env))] - [(tcomp-ext2-ρ tp-t tp-u f _ m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f f-acc f-sign m n shape-fn) (scalarize - (flat-ext2-ρ f m n shape-fn + (flat-ext2-ρ f f-acc m n shape-fn f-sign (ensure-flat (interp-tpromise tp-t env)) (ensure-flat (interp-tpromise tp-u env))))] - [(tcomp-ext1-ρ-scalar f _ tp) + [(tcomp-ext1-ρ-scalar f f-acc _ tp) (f (interp-tpromise tp env))] - [(tcomp-ext1-ρ f _ m shape-fn tp) + [(tcomp-ext1-ρ f f-acc f-sign m shape-fn tp) (scalarize - (flat-ext1-ρ f m shape-fn + (flat-ext1-ρ f f-acc m shape-fn f-sign (ensure-flat (interp-tpromise tp env))))] [(tcomp-reshape s tp) (let ([interp-tp (interp-tpromise tp env)]) diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 6b8437f..594eb37 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -2,8 +2,6 @@ (require "c0-ast.rkt") (require (only-in "c2-interpreter.rkt" interp-tensor interp-racket)) -(require "../../flat-tensors/ext-impl.rkt") -(require (prefix-in flat: "../../flat-tensors/tensors.rkt")) (require (only-in "c1-racket-runtime.rkt" runtime ext2-∇-result-res set-ext2-∇-result-res!)) @@ -67,7 +65,7 @@ [(tcomp-trefs tp (tcomp-ds-ref #f)) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-trefs tp^ (tcomp-ds-ref ref^)) (add1 ref^)))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref 
memo)) @@ -80,27 +78,27 @@ ((and (eqv? i 1) (not (tcomp-ds-ref-index (ext2-∇-result-res out-ref1)))) (set-ext2-∇-result-res! out-ref1 (tcomp-ds-ref ref^^^)))) - (values (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ + (values (tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0^ tp-t1^ tp-z^ out-ref0 out-ref1 i) (add1 ref^^^)))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) - (values (tcomp-ext1-∇ tp^ zp^ f sign m shape-fn) ref^^))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + (values (tcomp-ext1-∇ tp^ zp^ f f-acc sign m shape-fn) ref^^))] + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) - (values (tcomp-ext2-ρ-scalar f sign tp-t^ tp-u^) ref^^))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + (values (tcomp-ext2-ρ-scalar f f-acc sign tp-t^ tp-u^) ref^^))] + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) - (values (tcomp-ext2-ρ tp-t^ tp-u^ f sign m n shape-fn) ref^^))] - [(tcomp-ext1-ρ-scalar f sign tp) + (values (tcomp-ext2-ρ tp-t^ tp-u^ f f-acc sign m n shape-fn) ref^^))] + [(tcomp-ext1-ρ-scalar f f-acc sign tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) - (values (tcomp-ext1-ρ-scalar f sign tp^) ref^))] - [(tcomp-ext1-ρ f sign m shape-fn tp) + (values (tcomp-ext1-ρ-scalar f f-acc sign tp^) ref^))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) - (values (tcomp-ext1-ρ f sign m shape-fn tp^) ref^))] + (values (tcomp-ext1-ρ f f-acc sign m shape-fn tp^) ref^))] [(tcomp-reshape s tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-reshape s tp^) ref^))] @@ -173,22 +171,22 @@ (cr-tpromise tp counter^ uid^)] [(tcomp-trefs tp (and 
b (tcomp-ds-ref _))) (cr-tpromise tp counter^ uid^)] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ _ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) (cr-tpromise tp-t1 counter-2 uid-2))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f _ sign m shape-fn) (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) (cr-tpromise zp counter-1 uid-1))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f _ sign tp-t tp-u) (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) (cr-tpromise tp-u counter-1 uid-1))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f _ sign m n shape-fn) (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) (cr-tpromise tp-u counter-1 uid-1))] - [(tcomp-ext1-ρ-scalar f sign tp) + [(tcomp-ext1-ρ-scalar f _ sign tp) (cr-tpromise tp counter^ uid^)] - [(tcomp-ext1-ρ f sign m shape-fn tp) + [(tcomp-ext1-ρ f _ sign m shape-fn tp) (cr-tpromise tp counter^ uid^)] [(tcomp-reshape s tp) (cr-tpromise tp counter^ uid^)] @@ -263,7 +261,7 @@ (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-trefs instrs b) tc-counter-data)))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (->ecs (ecs-tpromise tp-t0 counter) (λ (t0-instrs) @@ -274,11 +272,11 @@ (ecs-tpromise tp-z counter) (λ (z-instrs) (inj-ecs-tcomp - (tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn + (tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn t0-instrs t1-instrs z-instrs out0 out1 i) tc-counter-data)))))))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) (->ecs (ecs-tpromise tp counter) (λ (t-instrs) @@ -286,9 +284,9 @@ (ecs-tpromise zp counter) (λ (z-instrs) (inj-ecs-tcomp - (tcomp-ext1-∇ t-instrs z-instrs f 
sign m shape-fn) + (tcomp-ext1-∇ t-instrs z-instrs f f-acc sign m shape-fn) tc-counter-data)))))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) (->ecs (ecs-tpromise tp-t counter) (λ (t-instrs) @@ -296,9 +294,9 @@ (ecs-tpromise tp-u counter) (λ (u-instrs) (inj-ecs-tcomp - (tcomp-ext2-ρ-scalar f sign t-instrs u-instrs) + (tcomp-ext2-ρ-scalar f f-acc sign t-instrs u-instrs) tc-counter-data)))))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) (->ecs (ecs-tpromise tp-t counter) (λ (t-instrs) @@ -306,18 +304,18 @@ (ecs-tpromise tp-u counter) (λ (u-instrs) (inj-ecs-tcomp - (tcomp-ext2-ρ t-instrs u-instrs f sign m n shape-fn) + (tcomp-ext2-ρ t-instrs u-instrs f f-acc sign m n shape-fn) tc-counter-data)))))] - [(tcomp-ext1-ρ-scalar f sign tp) + [(tcomp-ext1-ρ-scalar f f-acc sign tp) (->ecs (ecs-tpromise tp counter) (λ (instrs) - (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f sign instrs) tc-counter-data)))] - [(tcomp-ext1-ρ f sign m shape-fn tp) + (inj-ecs-tcomp (tcomp-ext1-ρ-scalar f f-acc sign instrs) tc-counter-data)))] + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (->ecs (ecs-tpromise tp counter) (λ (instrs) - (inj-ecs-tcomp (tcomp-ext1-ρ f sign m shape-fn instrs) tc-counter-data)))] + (inj-ecs-tcomp (tcomp-ext1-ρ f f-acc sign m shape-fn instrs) tc-counter-data)))] [(tcomp-reshape s tp) (->ecs (ecs-tpromise tp counter) @@ -389,16 +387,16 @@ ((number? 
t) t) (else (error 'gr-list->tensor "Unexpected: ~a" t)))) lst))) - `(flat:list->tensor (list ,@instrs-list)))] + `(acc:list->tensor (list ,@instrs-list)))] [(tcomp-tref tp (and i (tcomp-ds-ref _))) (let ((instrs (gr-tpromise tp)) (i-instrs (gr-tcomp i))) - `(flat:tref ,instrs ,i-instrs))] + `(acc:tref ,instrs ,i-instrs))] [(tcomp-trefs tp (and b (tcomp-ds-ref _))) (let ((instrs (gr-tpromise tp)) (b-instrs (gr-tcomp b))) `(rt:trefs ,instrs ,b-instrs))] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + [(tcomp-ext2-∇ fᵈ fᵈ-acc sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out0 out1 i) (let ((t0-instrs (gr-tpromise tp-t0)) (t1-instrs (gr-tpromise tp-t1)) (z-instrs (gr-tpromise tp-z)) @@ -409,36 +407,36 @@ [v (data-segment-ref index)]) (cond ((eqv? v 'uncalculated) - (ext2-∇-forcer! ,fᵈ ,r0 ,r1 ,shape-fn + (ext2-∇-forcer! ,fᵈ ,fᵈ-acc ,sign ,r0 ,r1 ,shape-fn ,t0-instrs ,t1-instrs ,z-instrs ,out-idx0 ,out-idx1) (data-segment-ref index)) (else v)))))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f f-acc sign m shape-fn) (let ((t-instrs (gr-tpromise tp)) (z-instrs (gr-tpromise zp))) `(scalarize - (flat-ext1-∇ ,f ,m ,shape-fn + (flat-ext1-∇ ,f ,f-acc ,m ,shape-fn ,sign (ensure-flat ,t-instrs) (ensure-flat ,z-instrs))))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f f-acc sign tp-t tp-u) (let ((t-instrs (gr-tpromise tp-t)) (u-instrs (gr-tpromise tp-u))) `(,f ,t-instrs ,u-instrs))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f f-acc sign m n shape-fn) (let ((t-instrs (gr-tpromise tp-t)) (u-instrs (gr-tpromise tp-u))) `(scalarize - (flat-ext2-ρ ,f ,m ,n ,shape-fn + (flat-ext2-ρ ,f ,f-acc ,m ,n ,shape-fn ,sign (ensure-flat ,t-instrs) (ensure-flat ,u-instrs))))] - [(tcomp-ext1-ρ-scalar f sign tp) + [(tcomp-ext1-ρ-scalar f f-acc sign tp) (let ((instrs (gr-tpromise tp))) `(,f ,instrs))] - [(tcomp-ext1-ρ f sign m shape-fn tp) + [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (let ((instrs (gr-tpromise 
tp))) `(scalarize - (flat-ext1-ρ ,f ,m ,shape-fn + (flat-ext1-ρ ,f ,f-acc ,m ,shape-fn ,sign (ensure-flat ,instrs))))] [(tcomp-reshape s tp) (let ((instrs (gr-tpromise tp))) diff --git a/lazy/tensors/test/test-1-reflect.rkt b/lazy/tensors/test/test-1-reflect.rkt index ce4481e..e266e65 100644 --- a/lazy/tensors/test/test-1-reflect.rkt +++ b/lazy/tensors/test/test-1-reflect.rkt @@ -1,5 +1,6 @@ (module+ test (require rackunit) + (require ffi/vector) (require "0-lazy.rkt") (require "B-test-programs.rkt") @@ -14,33 +15,33 @@ ((eval-res-1 res) (let* ((tp (th)) (forced (↓ tp))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? forced res (format "Expected result doesn't match in test case ~a" test-name)) (check-pred evaluated-tpromise? tp) - (check-equal? (tpromise-shape tp) (flat:shape forced)))) + (check-equal? (tpromise-shape tp) (acc:shape forced)))) ((eval-res-2 res1 res2) (let*-values (((tp1 tp2) (th)) ((forced1) (↓ tp1)) ((forced2) (↓ tp2))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? forced1 res1 (format "Expected first result doesn't match in test case ~a" test-name)) (check-pred evaluated-tpromise? tp1) - (check-equal? (tpromise-shape tp1) (flat:shape forced1)) - (flat:check-tensor-equal? + (check-equal? (tpromise-shape tp1) (acc:shape forced1)) + (acc:check-tensor-equal? forced2 res2 (format "Expected second result doesn't match in test case ~a" test-name)) (check-pred evaluated-tpromise? tp2) - (check-equal? (tpromise-shape tp2) (flat:shape forced2)))))) + (check-equal? (tpromise-shape tp2) (acc:shape forced2)))))) (define test-tensor-r1-0 (get-test-program 'tensor-r1-0)) - (check-false (flat:flat? (tpromise-tensor test-tensor-r1-0))) - (check-true (flat:flat? (car (unbox (tpromise-dst test-tensor-r1-0))))) + (check-false (acc:flat? (tpromise-tensor test-tensor-r1-0))) + (check-true (acc:flat? (car (unbox (tpromise-dst test-tensor-r1-0))))) (check-exn exn:fail? (λ () (tensor test-tensor-r1-0 4))) (check-exn exn:fail? 
(λ () (tensor 4 test-tensor-r1-0))) @@ -72,7 +73,7 @@ test-tensor-r1-0))) 1) 2))) - (flat:check-tensor-equal? (↓ test-tcomp-partial-eval) + (acc:check-tensor-equal? (↓ test-tcomp-partial-eval) (↓ (tensor 1 2 3))) (define test-id-scalar (get-test-program 'id-scalar)) @@ -80,7 +81,7 @@ (+-ρ test-id-scalar (get-test-program 'sum-nested))) (void (↓ test-id-scalar)) - (flat:check-tensor-equal? (↓ test-force-scalar) + (acc:check-tensor-equal? (↓ test-force-scalar) (↓ (tensor 19 21 20))) (define test-force-subexpr @@ -91,13 +92,13 @@ (+-ρ (get-test-program 'sum-nested) (get-test-program 'sum-nested)))) (void (↓ test-force-subexpr)) - (flat:check-tensor-equal? (↓ test-force-mutate) + (acc:check-tensor-equal? (↓ test-force-mutate) (↓ (tensor 27 33 30))) (define test-tp-r1 (tensor -1 -2 -3)) (define test-force-supexpr (abs-ρ test-tp-r1)) (void (↓ test-force-supexpr)) - (flat:check-tensor-equal? (↓ test-tp-r1) + (acc:check-tensor-equal? (↓ test-tp-r1) (↓ (tensor -1 -2 -3))) (define test-trefs (get-test-program 'tcomp-trefs)) @@ -109,9 +110,9 @@ (check-pred (λ (fs) (andmap (λ (e) (integer? (sqrt e))) fs)) - (vector->list (flat:flat-store (↓ test-build-random))) + (f32vector->list (acc:flat-store (↓ test-build-random))) "Side-effect of generating random tensor must only be run once") - (flat:check-tensor-equal? (↓ (get-test-program 'multi-built-tensor)) + (acc:check-tensor-equal? 
(↓ (get-test-program 'multi-built-tensor)) (eval-res-1-res (get-test-eval-res 'multi-built-tensor))) ) diff --git a/lazy/tensors/test/test-c2-interpreter.rkt b/lazy/tensors/test/test-c2-interpreter.rkt index c20dfdb..be70371 100644 --- a/lazy/tensors/test/test-c2-interpreter.rkt +++ b/lazy/tensors/test/test-c2-interpreter.rkt @@ -1,7 +1,7 @@ (module+ test (require rackunit) (require "B-test-programs.rkt") - (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (for (((test-name test-data) (in-hash test-programs))) (match-define (test-program-data th res) test-data) @@ -9,24 +9,24 @@ ((eval-res-1 res) (let* ((tp (th)) (interped (interp-tensor tp))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? interped res (format "Expected result doesn't match in test case ~a" test-name)) - (check-equal? (tpromise-shape tp) (flat:shape interped)))) + (check-equal? (tpromise-shape tp) (acc:shape interped)))) ((eval-res-2 res1 res2) (let*-values (((tp1 tp2) (th)) ((interped1) (interp-tensor tp1)) ((interped2) (interp-tensor tp2))) - (flat:check-tensor-equal? + (acc:check-tensor-equal? interped1 res1 (format "Expected first result doesn't match in test case ~a" test-name)) - (check-equal? (tpromise-shape tp1) (flat:shape interped1)) - (flat:check-tensor-equal? + (check-equal? (tpromise-shape tp1) (acc:shape interped1)) + (acc:check-tensor-equal? interped2 res2 (format "Expected second result doesn't match in test case ~a" test-name)) - (check-equal? (tpromise-shape tp2) (flat:shape interped2)))))) + (check-equal? 
(tpromise-shape tp2) (acc:shape interped2)))))) ) diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index 0585290..e7be52a 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -3,7 +3,7 @@ (require "B-test-programs.rkt") (require "0-lazy.rkt") (require "c2-interpreter.rkt") - (require (prefix-in flat: "../../flat-tensors/tensors.rkt")) + (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define current-test-program-name (make-parameter #f)) (define-check (check-compiler-invariants tp) @@ -18,7 +18,7 @@ ('test-name (current-test-program-name))) (for ((d ds)) (unless (or (number? d) - (flat:flat? d) + (acc:flat? d) (eqv? d 'uncalculated)) (fail-check (format (string-append "Data segment should only contain flat tensors " ", the symbol 'uncalculated or numbers." @@ -27,34 +27,34 @@ (parameterize ((cache (make-hash))) (let* ((instrs-dsr (generate-ds-refs tp)) (interp-dsr (interp-tensor instrs-dsr))) - (unless (flat:tensor-equal? interp-dsr interp-tp) + (unless (acc:tensor-equal? interp-dsr interp-tp) (fail-check (format (string-append "Result of interpreting pass generate-ds-ref doesn't" " match expected interpretation. Actual " - "interpretation: ~a~n")) - interp-dsr)) + "interpretation: ~a~n") + interp-dsr))) (let ((counter (count-references instrs-dsr))) (let* ((extracted (extract-common-subexpressions instrs-dsr counter)) (interp-extracted (interp-tensor extracted))) - (unless (flat:tensor-equal? interp-extracted interp-tp) + (unless (acc:tensor-equal? interp-extracted interp-tp) (fail-check (format (string-append "Result of interpreting pass" " extract-common-subexpression doesn't" " match expected interpretation. Actual " - "interpretation: ~a~n")) - interp-extracted)) + "interpretation: ~a~n") + interp-extracted))) (let* ((gr (generate-racket extracted)) (rkt (compile-racket gr)) (interp-rkt (interp-racket rkt ds))) - (unless (flat:tensor-equal? 
interp-rkt interp-tp) + (unless (acc:tensor-equal? interp-rkt interp-tp) (fail-check (format (string-append "Result of interpreting compiled racket code doesn't" " match expected interpretation. Actual " - "interpretation: ~a~n")) - interp-rkt)) + "interpretation: ~a~n") + interp-rkt))) (hash-set! (cache) signature rkt) (compile-tensor tp) (unless (eqv? (hash-count (cache)) 1) @@ -167,27 +167,27 @@ (else (error 'cdsr-list->tensor "Unexpected: ~a" l))))] [(tcomp-tref tp _) (count-tcomp-var tp)] [(tcomp-trefs tp _) (count-tcomp-var tp)] - [(tcomp-ext2-∇ fᵈ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z + [(tcomp-ext2-∇ fᵈ _ sign r0 r1 shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) (let ((c0 (count-tcomp-var tp-t0)) (c1 (count-tcomp-var tp-t1)) (cz (count-tcomp-var tp-z))) (+ c0 c1 cz))] - [(tcomp-ext1-∇ tp zp f sign m shape-fn) + [(tcomp-ext1-∇ tp zp f _ sign m shape-fn) (let ((ct (count-tcomp-var tp)) (cz (count-tcomp-var zp))) (+ ct cz))] - [(tcomp-ext2-ρ-scalar f sign tp-t tp-u) + [(tcomp-ext2-ρ-scalar f _ sign tp-t tp-u) (let ((ct (count-tcomp-var tp-t)) (cu (count-tcomp-var tp-u))) (+ ct cu))] - [(tcomp-ext2-ρ tp-t tp-u f sign m n shape-fn) + [(tcomp-ext2-ρ tp-t tp-u f _ sign m n shape-fn) (let ((ct (count-tcomp-var tp-t)) (cu (count-tcomp-var tp-u))) (+ ct cu))] - [(tcomp-ext1-ρ-scalar f sign tp) (count-tcomp-var tp)] - [(tcomp-ext1-ρ f sign m shape-fn tp) (count-tcomp-var tp)] + [(tcomp-ext1-ρ-scalar f _ sign tp) (count-tcomp-var tp)] + [(tcomp-ext1-ρ f _ sign m shape-fn tp) (count-tcomp-var tp)] [(tcomp-reshape s tp) (count-tcomp-var tp)] [(tcomp-ds-ref i) 0] [(tcomp-let lhs rhs body)