diff --git a/lazy/autodiff/B-prims.rkt b/lazy/autodiff/B-prims.rkt index e14ffab..5b4b252 100644 --- a/lazy/autodiff/B-prims.rkt +++ b/lazy/autodiff/B-prims.rkt @@ -1,12 +1,13 @@ #lang racket -(require (only-in "../../accelerated-tensors/ext-impl.rkt" - new-vec - apply-flat-ρ-fn-1 - apply-flat-ρ-fn-2 - apply-flat-∇-fn-1 - apply-flat-∇-fn-2)) +(require (only-in "../tensors/c0-ast.rkt" + tpmake-prim1-ρ + tpmake-prim2-ρ + tpmake-prim1-∇ + tpmake-prim2-∇)) (require "../tensors.rkt") +(require (only-in "../tensors/c1-racket-runtime.rkt" ext2-∇-result)) +(require (only-in "../tensors/c0-ast.rkt" tcomp-ds-ref)) (require "A-autodiff.ss") (struct prim (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape-fn signature expects-prealloc? proc) @@ -21,8 +22,12 @@ (set! id (add1 id)) (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da) - (prim1-dual (if #;#f expects-prealloc? (preallocated->functional-1-ρ ρ-fn shape) ρ-fn) - (if #;#f expects-prealloc? (preallocated->functional-1-∇ ∇-fn shape) ∇-fn) + (prim1-dual (if expects-prealloc? + (preallocated->functional-1-ρ ρ-fn ρ-acc-fn prim-sign shape) + ρ-fn) + (if expects-prealloc? + (preallocated->functional-1-∇ ∇-fn ∇-acc-fn prim-sign shape) + ∇-fn) da))))))) ;; TODO: Convert the use of force* into the construction of an AST so that we @@ -43,8 +48,12 @@ (set! id (add1 id)) (prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc? (λ (da db) - (prim2-dual (if expects-prealloc? (preallocated->functional-2-ρ ρ-fn shape) ρ-fn) - (if expects-prealloc? (preallocated->functional-2-∇ ∇-fn shape) ∇-fn) + (prim2-dual (if expects-prealloc? + (preallocated->functional-2-ρ ρ-fn ρ-acc-fn prim-sign shape) + ρ-fn) + (if expects-prealloc? + (preallocated->functional-2-∇ ∇-fn ∇-acc-fn prim-sign shape) + ∇-fn) da db))))))) (define prim2-dual @@ -64,41 +73,28 @@ ;;---------------------------- (define preallocated->functional-1-ρ - (λ (ρ-fn shape-fn) + (λ (ρ-fn ρ-fn-acc prim-sign shape-fn) (λ (ra) - (force*1 ra - (λ (ra) - (apply-flat-ρ-fn-1 ρ-fn ra shape-fn)))))) + (tpmake-prim1-ρ ρ-fn ρ-fn-acc prim-sign shape-fn ra)))) (define preallocated->functional-1-∇ - (λ (∇-fn shape-fn) + (λ (∇-fn ∇-fn-acc prim-sign shape-fn) (λ (ra z) - (force*2 - (λ () - (values ra z)) - (λ (ra z) - (apply-flat-∇-fn-1 ∇-fn ra z shape-fn)))))) + (tpmake-prim1-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra z)))) (define preallocated->functional-2-ρ - (λ (ρ-fn shape-fn) + (λ (ρ-fn ρ-fn-acc prim-sign shape-fn) (λ (ra rb) - (force*2 - (λ () - (values ra rb)) - (λ (ra rb) - (apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn)))))) + (tpmake-prim2-ρ ρ-fn ρ-fn-acc prim-sign shape-fn ra rb)))) (define preallocated->functional-2-∇ - (λ (∇-fn shape-fn) + (λ (∇-fn ∇-fn-acc prim-sign shape-fn) (λ (ra rb z) - (force*2 - (λ () - (values ra rb)) - (λ (ra rb) - (force*1 - z - (λ (z) - (apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn)))))))) + (let ((out-ref0 (ext2-∇-result (tcomp-ds-ref #f))) + (out-ref1 (ext2-∇-result (tcomp-ds-ref #f)))) + (values + (tpmake-prim2-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra rb z out-ref0 out-ref1 0) + (tpmake-prim2-∇ ∇-fn ∇-fn-acc prim-sign shape-fn ra rb z out-ref0 out-ref1 1)))))) ;;---------------------------- ;; Dualized tensor op creators diff --git a/lazy/autodiff/D-test-helpers.rkt b/lazy/autodiff/D-test-helpers.rkt index dde0e59..a680951 100644 --- a/lazy/autodiff/D-test-helpers.rkt +++ b/lazy/autodiff/D-test-helpers.rkt @@ -1,10 +1,20 @@ #lang racket (require "../tensors.rkt") +(require "../tensors/c0-ast.rkt") (require "A-autodiff.ss") +(require (except-in "../../accelerated-tensors/ext-impl.rkt" + scalarize)) (require rackunit) +(define force-print-store + (λ (t) + (with-output-to-string + (λ () + (print-vec (flat-store (↓ t) + #;(list-ref (unbox (tpromise-dst t)) 0))))))) + (define-binary-check (check-dual-equal? equal-wt? actual expected)) (define-check (ρ-∇-checker fn args ans grads) (let* ((y (apply fn args)) @@ -14,11 +24,11 @@ ((and (equal-wt? ans-ρ (ρ y)) (equal-wt? grads (ρ g))) (void)) ((equal-wt? ans-ρ (ρ y)) - (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~%~s~%" - (ρ g) grads))) + (fail-check (format "Gradients failed to match.~%actual:~%~s~%expected:~%~s~%~%actual store:~%~a~%expected store:~%~a~%" + (ρ g) grads (map force-print-store (ρ g)) (map force-print-store grads)))) (else - (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~%~s~%" - (ρ y) ans-ρ)))))) + (fail-check (format "Answers failed to match.~%actual:~%~s~%expected:~%~s~%~%actual store:~%~a~%expected store:~%~a~%" + (ρ y) ans-ρ (force-print-store (ρ y)) (force-print-store ans-ρ))))))) (define-syntax check-ρ-∇ (syntax-rules () diff --git a/lazy/ext-ops/test/test-C-star-2-1.rkt b/lazy/ext-ops/test/test-C-star-2-1.rkt index bbb2c8b..233466e 100644 --- a/lazy/ext-ops/test/test-C-star-2-1.rkt +++ b/lazy/ext-ops/test/test-C-star-2-1.rkt @@ -6,6 +6,10 @@ (tensor 7 8 9 10))) (b (tensor 2 3 4 5))) (check-ρ-∇ (d*-2-1 a b) + (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) + (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) + (tensor 10.0 12.0 14.0 16.0))) + (check-ρ-∇ (*-2-1 a b) (tensor (tensor 6 12 20 30) (tensor 14 24 36 50)) (list (tensor (tensor 2.0 3.0 4.0 5.0) (tensor 2.0 3.0 4.0 5.0)) (tensor 10.0 12.0 14.0 16.0)))) diff --git a/lazy/ext-ops/test/test-D-sum.rkt b/lazy/ext-ops/test/test-D-sum.rkt index e77ab0a..9271f11 100644 --- a/lazy/ext-ops/test/test-D-sum.rkt +++ b/lazy/ext-ops/test/test-D-sum.rkt @@ -5,6 +5,8 @@ (require (only-in "A-scalar-ops.ss" d-sqr d* d-)) (let ((a (tensor 3 4 5))) + (check-ρ-∇ (sum-1 a) 12 + (list (tensor 1.0 1.0 1.0))) (check-ρ-∇ (d-sum a) 12 (list (tensor 1.0 1.0 1.0)))) @@ -50,4 +52,15 @@ (list (tensor (tensor 14.0 16.0 18.0 20.0) (tensor 14.0 16.0 18.0 20.0)) (tensor (tensor 10.0 12.0 14.0 16.0) - (tensor 10.0 12.0 14.0 16.0)))))) + (tensor 10.0 12.0 14.0 16.0))))) + (let ((a (tensor (tensor 0 1 2) + (tensor 3 4 5) + (tensor 6 7 8)))) + (check-ρ-∇ (sum-cols-2 a) (tensor 9 12 15) + (list (tensor (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0)))) + (check-ρ-∇ (d-sum-cols a) (tensor 9 12 15) + (list (tensor (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0)))))) diff --git a/lazy/ext-ops/test/test-E-argmax.rkt b/lazy/ext-ops/test/test-E-argmax.rkt index f72819e..144980a 100644 --- a/lazy/ext-ops/test/test-E-argmax.rkt +++ b/lazy/ext-ops/test/test-E-argmax.rkt @@ -2,6 +2,8 @@ (require (only-in "../tensors.rkt" tensor)) (let ((y (tensor 0.0 0.0 1.0 0.0))) + (check-ρ-∇ (argmax-1 y) 2.0 + (list (tensor 0.0 0.0 0.0 0.0))) (check-ρ-∇ (d-argmax y) 2.0 (list (tensor 0.0 0.0 0.0 0.0)))) diff --git a/lazy/ext-ops/test/test-F-max.rkt b/lazy/ext-ops/test/test-F-max.rkt index 01ab1a5..88d74a5 100644 --- a/lazy/ext-ops/test/test-F-max.rkt +++ b/lazy/ext-ops/test/test-F-max.rkt @@ -2,6 +2,10 @@ (require rackunit) (require (only-in "../tensors.rkt" tensor)) + (let ((y (tensor 0.0 1.0 0.0 0.0))) + (check-ρ-∇ (max-1 y) 1.0 (list y)) + (check-ρ-∇ (d-max y) 1.0 (list y))) + (let ((y (tensor (tensor 0.0 0.0 1.0 0.0) (tensor 0.0 1.0 0.0 0.0) (tensor 1.0 0.0 0.0 0.0) diff --git a/lazy/ext-ops/test/test-K-concat.rkt b/lazy/ext-ops/test/test-K-concat.rkt index b427ced..aa405fc 100644 --- a/lazy/ext-ops/test/test-K-concat.rkt +++ b/lazy/ext-ops/test/test-K-concat.rkt @@ -8,6 +8,11 @@ (define r1-t2 (tensor 5.0 6.0 7.0)) (define r1-t1 (tensor 3.0 4.0 5.0 6.0 7.0)) + (check-ρ-∇ (concat-1-1 r1-t2 r1-t1) + (tensor 5.0 6.0 7.0 3.0 4.0 5.0 6.0 7.0) + (list (tensor 1.0 1.0 1.0) + (tensor 1.0 1.0 1.0 1.0 1.0))) + (check-dual-equal? (d-concat r2-t1 r1-t2) (tensor (tensor 3.0 4.0 5.0 6.0 7.0) diff --git a/lazy/tensors/c0-ast.rkt b/lazy/tensors/c0-ast.rkt index 1e2d9e1..ae03f89 100644 --- a/lazy/tensors/c0-ast.rkt +++ b/lazy/tensors/c0-ast.rkt @@ -37,6 +37,10 @@ tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) #:transparent) +(struct tcomp-prim1-ρ tcomp (f f-acc sign shape-fn tp) #:transparent) +(struct tcomp-prim1-∇ tcomp (f f-acc sign shape-fn tp zp) #:transparent) +(struct tcomp-prim2-ρ tcomp (f f-acc sign shape-fn tp-t tp-u) #:transparent) +(struct tcomp-prim2-∇ tcomp (f f-acc sign shape-fn tp-t tp-u zp out-ref0 out-ref1 i) #:transparent) (struct tcomp-reshape tcomp (s tp) #:transparent) (struct tcomp-let tcomp (lhs rhs body) #:transparent) (struct tcomp-var tcomp (name) #:transparent) @@ -180,6 +184,25 @@ (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) #"dsr" (number->bytes i))))) +(define gs-prim1-ρ + (λ (prim-sign tp) + (box (list #"p1r" (string->bytes prim-sign) (tpromise-sign tp))))) + +(define gs-prim2-ρ + (λ (signature tp-t tp-u) + (box (list #"p2r" (string->bytes signature) + (tpromise-sign tp-t) (tpromise-sign tp-u))))) + +(define gs-prim1-∇ + (λ (signature tp zp) + (box (list #"p1n" (string->bytes signature) (tpromise-sign tp) (tpromise-sign zp))))) + +(define gs-prim2-∇ + (λ (signature tp-t0 tp-t1 tp-z i) + (box (list #"p2n" (string->bytes signature) + (tpromise-sign tp-t0) (tpromise-sign tp-t1) (tpromise-sign tp-z) + #"dsr" (number->bytes i))))) + (define gs-reshape (λ (shape tp) (box (list* #"r" (tpromise-sign tp) (map number->bytes shape))))) @@ -278,6 +301,42 @@ (gdst-ext2-∇ tp-t0 tp-t1 tp-z) (gs-ext2-∇ prim-sign r0 r1 tp-t0 tp-t1 tp-z i))))) +(define tpmake-prim1-ρ + (λ (f f-acc prim-sign shape-fn tp) + (tpromise (tcomp-prim1-ρ f f-acc prim-sign shape-fn tp) + (shape-fn (tpromise-shape tp)) + (box (list (tpromise-dst tp))) + (gs-prim1-ρ prim-sign tp)))) + +(define tpmake-prim2-ρ + (λ (f f-acc prim-sign shape-fn tp-t tp-u) + (tpromise + (tcomp-prim2-ρ f f-acc prim-sign shape-fn tp-t tp-u) + (shape-fn (tpromise-shape tp-t) (tpromise-shape tp-u)) + (box (list (tpromise-dst tp-t) (tpromise-dst tp-u))) + (gs-prim2-ρ prim-sign tp-t tp-u)))) + +(define tpmake-prim1-∇ + (λ (f f-acc prim-sign shape-fn tp zp) + (let ((zp (ensure-tpromise zp))) + (tpromise + (tcomp-prim1-∇ f f-acc prim-sign shape-fn tp zp) + (tpromise-shape tp) + (box (list (tpromise-dst tp) (tpromise-dst zp))) + (gs-prim1-∇ prim-sign tp zp))))) + +(define tpmake-prim2-∇ + (λ (fᵈ fᵈ-acc prim-sign shape-fn tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + (let ((tp-t0 (ensure-tpromise tp-t0)) + (tp-t1 (ensure-tpromise tp-t1)) + (tp-z (ensure-tpromise tp-z))) + (tpromise + (tcomp-prim2-∇ fᵈ fᵈ-acc prim-sign shape-fn + tp-t0 tp-t1 tp-z out-ref0 out-ref1 i) + (if (zero? i) (tpromise-shape tp-t0) (tpromise-shape tp-t1)) + (gdst-ext2-∇ tp-t0 tp-t1 tp-z) ;; dst constucted in the same way as ext2-∇ + (gs-prim2-∇ prim-sign tp-t0 tp-t1 tp-z i))))) + (define tpmake-reshape (λ (tp shape) (tpromise @@ -295,6 +354,10 @@ (struct-out tcomp-ext2-ρ) (struct-out tcomp-ext1-∇) (struct-out tcomp-ext2-∇) + (struct-out tcomp-prim1-ρ) + (struct-out tcomp-prim2-ρ) + (struct-out tcomp-prim1-∇) + (struct-out tcomp-prim2-∇) (struct-out tcomp-reshape) (struct-out tcomp-let) (struct-out tcomp-var) @@ -314,4 +377,8 @@ tpmake-ext2-ρ tpmake-ext1-∇ tpmake-ext2-∇ + tpmake-prim1-ρ + tpmake-prim2-ρ + tpmake-prim1-∇ + tpmake-prim2-∇ tpmake-reshape) diff --git a/lazy/tensors/c1-racket-runtime.rkt b/lazy/tensors/c1-racket-runtime.rkt index 9fad17f..70561b8 100644 --- a/lazy/tensors/c1-racket-runtime.rkt +++ b/lazy/tensors/c1-racket-runtime.rkt @@ -55,6 +55,33 @@ (when out-idx1 (data-segment-set! out-idx1 (scalarize (flat s1 g1 0)))))))))) +(define prim2-∇-forcer! + (λ (fᵈ fᵈ-acc fᵈ-sign shape-fn t0 t1 z out-idx0 out-idx1) + (let* ((in-shape-a (flat-shape t0)) + (in-size-a (size-of in-shape-a)) + (in-shape-b (flat-shape t1)) + (in-size-b (size-of in-shape-b)) + (out-shape (shape-fn in-shape-a in-shape-b)) + (out-size (size-of out-shape))) + (let ((g0 (new-vec in-size-a 0.0)) + (g1 (new-vec in-size-b 0.0))) + (cond + ((null? out-shape) + (let ((v-z (new-vec 1 z))) + (fᵈ g0 g1 + (flat-store t0) (flat-offset t0) in-size-a + (flat-store t1) (flat-offset t1) in-size-b + v-z 0 1))) + (else + (fᵈ g0 g1 + (flat-store t0) (flat-offset t0) in-size-a + (flat-store t1) (flat-offset t1) in-size-b + (flat-store z) (flat-offset z) out-size))) + (when out-idx0 + (data-segment-set! out-idx0 (scalarize (flat in-shape-a g0 0)))) + (when out-idx1 + (data-segment-set! out-idx1 (scalarize (flat in-shape-b g1 0)))))))) + (define rt:trefs (λ (ft b) (cond @@ -79,4 +106,6 @@ (provide runtime flat? acc:build-tensor acc:list->tensor acc:tref rt:trefs (struct-out ext2-∇-result) set-ext2-∇-result-res! ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ - flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref) + flat flat-store flat-offset flat-ext1-ρ data-segment data-segment-ref + apply-flat-ρ-fn-1 apply-flat-ρ-fn-2 apply-flat-∇-fn-1 apply-flat-∇-fn-2 + prim2-∇-forcer!) diff --git a/lazy/tensors/c2-interpreter.rkt b/lazy/tensors/c2-interpreter.rkt index ebd84c4..7bbba2c 100644 --- a/lazy/tensors/c2-interpreter.rkt +++ b/lazy/tensors/c2-interpreter.rkt @@ -6,7 +6,8 @@ set-ext2-∇-result-res! acc:tref rt:trefs ext2-∇-result-res ext2-∇-forcer! scalarize flat-ext1-∇ ensure-flat flat-ext2-ρ flat flat-store flat-offset flat-ext1-ρ data-segment - data-segment-ref)) + apply-flat-ρ-fn-1 apply-flat-ρ-fn-2 apply-flat-∇-fn-1 apply-flat-∇-fn-2 + data-segment-ref prim2-∇-forcer!)) (define interp-tcomp (λ (tc env) @@ -67,6 +68,36 @@ (scalarize (flat-ext1-ρ f f-acc m shape-fn f-sign (ensure-flat (interp-tpromise tp env))))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (apply-flat-ρ-fn-1 f (interp-tpromise tp) shape-fn)] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (apply-flat-ρ-fn-2 f (interp-tensor tp-t) (interp-tpromise tp-u) shape-fn)] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (apply-flat-∇-fn-1 f (interp-tpromise tp) (scalarize (interp-tpromise zp)) shape-fn)] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (interp-tpromise tp-t0 env)) + (t1-instrs (interp-tpromise tp-t1 env)) + (z-instrs (interp-tpromise tp-z env))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out0)))) + (set-ext2-∇-result-res! out0 (tcomp-ds-ref (current-ds-ref-index))) + (current-ds-ref-index (add1 (current-ds-ref-index)))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (set-ext2-∇-result-res! out1 (tcomp-ds-ref (current-ds-ref-index))) + (current-ds-ref-index (add1 (current-ds-ref-index))))) + (let* ((out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1))) + (index (if (zero? i) out-idx0 out-idx1)) + (v (data-segment-ref index))) + (cond + ((eqv? v 'uncalculated) + (prim2-∇-forcer! fᵈ f-sign shape-fn + t0-instrs t1-instrs + z-instrs out-idx0 out-idx1) + (data-segment-ref index)) + (else v))))] [(tcomp-reshape s tp) (let ([interp-tp (interp-tpromise tp env)]) (flat s (flat-store interp-tp) (flat-offset interp-tp)))] diff --git a/lazy/tensors/c3-compiler.rkt b/lazy/tensors/c3-compiler.rkt index 594eb37..15d171c 100644 --- a/lazy/tensors/c3-compiler.rkt +++ b/lazy/tensors/c3-compiler.rkt @@ -99,6 +99,31 @@ [(tcomp-ext1-ρ f f-acc sign m shape-fn tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-ext1-ρ f f-acc sign m shape-fn tp^) ref^))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) + (values (tcomp-prim1-ρ f f-acc sign shape-fn tp^) ref^))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let*-values (((tp-t^ ref^) (gdr-tpromise tp-t ref memo)) + ((tp-u^ ref^^) (gdr-tpromise tp-u ref^ memo))) + (values (tcomp-prim2-ρ f f-acc sign shape-fn tp-t^ tp-u^) ref^^))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let*-values (((tp^ ref^) (gdr-tpromise tp ref memo)) + ((zp^ ref^^) (gdr-tpromise zp ref^ memo))) + (values (tcomp-prim1-∇ f f-acc sign shape-fn tp^ zp^) ref^^))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((tp-t0^ ref^) (gdr-tpromise tp-t0 ref memo)) + ((tp-t1^ ref^^) (gdr-tpromise tp-t1 ref^ memo)) + ((tp-z^ ref^^^) (gdr-tpromise tp-z ref^^ memo))) + (cond + ((and (eqv? i 0) + (not (tcomp-ds-ref-index (ext2-∇-result-res out0)))) + (set-ext2-∇-result-res! out0 (tcomp-ds-ref ref^^^))) + ((and (eqv? i 1) + (not (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (set-ext2-∇-result-res! out1 (tcomp-ds-ref ref^^^)))) + (values (tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0^ tp-t1^ tp-z^ + out0 out1 i) + (add1 ref^^^)))] [(tcomp-reshape s tp) (let-values (((tp^ ref^) (gdr-tpromise tp ref memo))) (values (tcomp-reshape s tp^) ref^))] @@ -188,6 +213,18 @@ (cr-tpromise tp counter^ uid^)] [(tcomp-ext1-ρ f _ sign m shape-fn tp) (cr-tpromise tp counter^ uid^)] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (cr-tpromise tp counter^ uid^)] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let-values (((counter-1 uid-1) (cr-tpromise tp-t counter^ uid^))) + (cr-tpromise tp-u counter-1 uid-1))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let-values (((counter-1 uid-1) (cr-tpromise tp counter^ uid^))) + (cr-tpromise zp counter-1 uid-1))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let*-values (((counter-1 uid-1) (cr-tpromise tp-t0 counter^ uid^)) + ((counter-2 uid-2) (cr-tpromise tp-z counter-1 uid-1))) + (cr-tpromise tp-t1 counter-2 uid-2))] [(tcomp-reshape s tp) (cr-tpromise tp counter^ uid^)] [(tcomp-ds-ref index) (values counter^ uid^)] @@ -316,6 +353,46 @@ (ecs-tpromise tp counter) (λ (instrs) (inj-ecs-tcomp (tcomp-ext1-ρ f f-acc sign m shape-fn instrs) tc-counter-data)))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (->ecs + (ecs-tpromise tp counter) + (λ (instrs) + (inj-ecs-tcomp (tcomp-prim1-ρ f f-acc sign shape-fn instrs) tc-counter-data)))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (->ecs + (ecs-tpromise tp-t counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise tp-u counter) + (λ (u-instrs) + (inj-ecs-tcomp + (tcomp-prim2-ρ f f-acc sign shape-fn t-instrs u-instrs) + tc-counter-data)))))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (->ecs + (ecs-tpromise tp counter) + (λ (t-instrs) + (->ecs + (ecs-tpromise zp counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-prim1-∇ f f-acc sign shape-fn t-instrs z-instrs) + tc-counter-data)))))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (->ecs + (ecs-tpromise tp-t0 counter) + (λ (t0-instrs) + (->ecs + (ecs-tpromise tp-t1 counter) + (λ (t1-instrs) + (->ecs + (ecs-tpromise tp-z counter) + (λ (z-instrs) + (inj-ecs-tcomp + (tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn + t0-instrs t1-instrs z-instrs + out0 out1 i) + tc-counter-data)))))))] [(tcomp-reshape s tp) (->ecs (ecs-tpromise tp counter) @@ -438,6 +515,33 @@ `(scalarize (flat-ext1-ρ ,f ,f-acc ,m ,shape-fn ,sign (ensure-flat ,instrs))))] + [(tcomp-prim1-ρ f f-acc sign shape-fn tp) + (let ((instrs (gr-tpromise tp))) + `(apply-flat-ρ-fn-1 ,f ,instrs ,shape-fn))] + [(tcomp-prim2-ρ f f-acc sign shape-fn tp-t tp-u) + (let ((t-instrs (gr-tpromise tp-t)) + (u-instrs (gr-tpromise tp-u))) + `(apply-flat-ρ-fn-2 ,f ,t-instrs ,u-instrs ,shape-fn))] + [(tcomp-prim1-∇ f f-acc sign shape-fn tp zp) + (let ((t-instrs (gr-tpromise tp)) + (z-instrs (gr-tpromise zp))) + `(apply-flat-∇-fn-1 ,f ,t-instrs (scalarize ,z-instrs) ,shape-fn))] + [(tcomp-prim2-∇ fᵈ fᵈ-acc f-sign shape-fn tp-t0 tp-t1 tp-z out0 out1 i) + (let ((t0-instrs (gr-tpromise tp-t0)) + (t1-instrs (gr-tpromise tp-t1)) + (z-instrs (gr-tpromise tp-z)) + (out-idx0 (tcomp-ds-ref-index (ext2-∇-result-res out0))) + (out-idx1 (tcomp-ds-ref-index (ext2-∇-result-res out1)))) + (let ((index (if (zero? i) out-idx0 out-idx1))) + `(let* ([index ,index] + [v (data-segment-ref index)]) + (cond + ((eqv? v 'uncalculated) + (prim2-∇-forcer! ,fᵈ ,fᵈ-acc ,f-sign ,shape-fn + ,t0-instrs ,t1-instrs + ,z-instrs ,out-idx0 ,out-idx1) + (data-segment-ref index)) + (else v)))))] [(tcomp-reshape s tp) (let ((instrs (gr-tpromise tp))) `(flat ',s diff --git a/lazy/tensors/test/test-c3-compiler.rkt b/lazy/tensors/test/test-c3-compiler.rkt index e7be52a..6bae788 100644 --- a/lazy/tensors/test/test-c3-compiler.rkt +++ b/lazy/tensors/test/test-c3-compiler.rkt @@ -3,6 +3,8 @@ (require "B-test-programs.rkt") (require "0-lazy.rkt") (require "c2-interpreter.rkt") + (require (prefix-in acc: (only-in "../../accelerated-tensors/autodiff.rkt" + make-printable))) (require (prefix-in acc: "../../accelerated-tensors/tensors.rkt")) (define current-test-program-name (make-parameter #f)) @@ -14,7 +16,7 @@ (('data-segment ds) ('signature signature) ('input-computation (tpromise-tensor tp)) - ('expected-interpretation interp-tp) + ('expected-interpretation (acc:make-printable interp-tp)) ('test-name (current-test-program-name))) (for ((d ds)) (unless (or (number? d) @@ -44,7 +46,7 @@ " extract-common-subexpression doesn't" " match expected interpretation. Actual " "interpretation: ~a~n") - interp-extracted))) + (acc:make-printable interp-extracted)))) (let* ((gr (generate-racket extracted)) (rkt (compile-racket gr)) (interp-rkt (interp-racket rkt ds))) @@ -54,7 +56,7 @@ "Result of interpreting compiled racket code doesn't" " match expected interpretation. Actual " "interpretation: ~a~n") - interp-rkt))) + (acc:make-printable interp-rkt)))) (hash-set! (cache) signature rkt) (compile-tensor tp) (unless (eqv? (hash-count (cache)) 1)