Skip to content

Commit

Permalink
[add-lazy]Switch compiled tensor runtime to acc tensor impl
Browse files Browse the repository at this point in the history
  • Loading branch information
DarshalShetty committed Jul 17, 2024
1 parent b656785 commit 0930df7
Show file tree
Hide file tree
Showing 25 changed files with 908 additions and 438 deletions.
5 changes: 3 additions & 2 deletions lazy/autodiff/A-autodiff.rkt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#lang racket

(require string-interpolation)
(require "../tensors.rkt")

;;----------------------------
Expand Down Expand Up @@ -52,7 +53,7 @@
(hash-set σ d (+-ρ z g)))))

(define +-ρ
(ext2-ρ + 0 0))
(ext2-ρ + (λ (a b) "@{a} + @{b}") 0 0))

;;----------------------------
;; Reverse-mode AD
Expand Down Expand Up @@ -111,7 +112,7 @@
((dual? v) (trace-print (ρ v) port))
(else (fprintf port "~a~%" v)))))

(define (one-like s) ((ext1-ρ (λ (x) 1.0) 0) s))
(define (one-like s) ((ext1-ρ (λ (x) 1.0) (λ (x) "1.0") 0) s))

(include "test/test-A-autodiff.rkt")

Expand Down
89 changes: 74 additions & 15 deletions lazy/autodiff/B-prims.rkt
Original file line number Diff line number Diff line change
@@ -1,21 +1,32 @@
#lang racket

(require (only-in "../../accelerated-tensors/ext-impl.rkt"
new-vec
apply-flat-ρ-fn-1
apply-flat-ρ-fn-2
apply-flat-∇-fn-1
apply-flat-∇-fn-2))
(require "../tensors.rkt")
(require "A-autodiff.ss")

(struct prim (ρ-fn -fn shape-fn signature expects-prealloc? proc)
(struct prim (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape-fn signature expects-prealloc? proc)
#:property prop:procedure (λ (this . args)
(apply (prim-proc this) args)))

;;TODO: Add new ast nodes for the 4 forces being done in the four preallocated->functional-* functions
(define prim1
(let ((id 0))
(λ (ρ-fn -fn [shape (λ (l . r) l)] [expects-prealloc? #f])
(λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f])
(let ((prim-sign (string-append "p1" (~r id #:base 16))))
(set! id (add1 id))
(prim ρ-fn -fn shape prim-sign expects-prealloc?
(prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc?
(λ (da)
(prim1-dual ρ-fn ∇-fn da)))))))
(prim1-dual (if #;#f expects-prealloc? (preallocated->functional-1-ρ ρ-fn shape) ρ-fn)
(if #;#f expects-prealloc? (preallocated->functional-1-∇ ∇-fn shape) ∇-fn)
da)))))))

;; TODO: Convert the use of force* into the construction of an AST so that we
;; don't prematurely trigger computation.
(define prim1-dual
(λ (ρ-fn ∇-fn da)
(let ((ra (ρ da)))
Expand All @@ -27,12 +38,14 @@

(define prim2
(let ((id 0))
(λ (ρ-fn -fn [shape (λ (l . r) l)] [expects-prealloc? #f])
(λ (ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn [shape (λ (l . r) l)] [expects-prealloc? #f])
(let ((prim-sign (string-append "p2" (~r id #:base 16))))
(set! id (add1 id))
(prim ρ-fn -fn shape prim-sign expects-prealloc?
(prim ρ-fn ρ-acc-fn ∇-fn ∇-acc-fn shape prim-sign expects-prealloc?
(λ (da db)
(prim2-dual ρ-fn ∇-fn da db)))))))
(prim2-dual (if expects-prealloc? (preallocated->functional-2-ρ ρ-fn shape) ρ-fn)
(if expects-prealloc? (preallocated->functional-2-∇ ∇-fn shape) ∇-fn)
da db)))))))

(define prim2-dual
(λ (ρ-fn ∇-fn da db)
Expand All @@ -45,6 +58,48 @@
(let ((σ-hat ((κ da) da ga σ)))
((κ db) db gb σ-hat)))))))))

;;----------------------------
;; Managing flat-optimized and
;; non-flat ρ and ∇ functions
;;----------------------------

(define preallocated->functional-1-ρ
(λ (ρ-fn shape-fn)
(λ (ra)
(force*1 ra
(λ (ra)
(apply-flat-ρ-fn-1 ρ-fn ra shape-fn))))))

(define preallocated->functional-1-∇
(λ (∇-fn shape-fn)
(λ (ra z)
(force*2
(λ ()
(values ra z))
(λ (ra z)
(apply-flat-∇-fn-1 ∇-fn ra z shape-fn))))))

(define preallocated->functional-2-ρ
(λ (ρ-fn shape-fn)
(λ (ra rb)
(force*2
(λ ()
(values ra rb))
(λ (ra rb)
(apply-flat-ρ-fn-2 ρ-fn ra rb shape-fn))))))

(define preallocated->functional-2-∇
(λ (∇-fn shape-fn)
(λ (ra rb z)
(force*2
(λ ()
(values ra rb))
(λ (ra rb)
(force*1
z
(λ (z)
(apply-flat-∇-fn-2 ∇-fn ra rb z shape-fn))))))))

;;----------------------------
;; Dualized tensor op creators
;;----------------------------
Expand All @@ -53,10 +108,12 @@
(unless (prim? f)
(error 'ext1-prim "Function to be extended must be a primitive. Found: ~a" f))
(prim1
(ext1-ρ (prim-ρ-fn f) n (prim-shape-fn f)
(prim-expects-prealloc? f) (prim-signature f))
(ext1-∇ (prim-∇-fn f) n (prim-shape-fn f)
(prim-expects-prealloc? f) (prim-signature f))
(ext1-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) n (prim-shape-fn f)
(prim-expects-prealloc? f) (string-append "r" (prim-signature f)))
(prim-ρ-acc-fn f)
(ext1-∇ (prim-∇-fn f) (prim-∇-acc-fn f) n (prim-shape-fn f)
(prim-expects-prealloc? f) (string-append "n" (prim-signature f)))
(prim-∇-acc-fn f)
(prim-shape-fn f)
#f)))

Expand All @@ -65,10 +122,12 @@
(unless (prim? f)
(error 'ext2-prim "Function to be extended must be a primitive. Found: ~a" f))
(prim2
(ext2-ρ (prim-ρ-fn f) m n (prim-shape-fn f)
(prim-expects-prealloc? f) (prim-signature f))
(ext2-∇ (prim-∇-fn f) m n (prim-shape-fn f)
(prim-expects-prealloc? f) (prim-signature f))
(ext2-ρ (prim-ρ-fn f) (prim-ρ-acc-fn f) m n (prim-shape-fn f)
(prim-expects-prealloc? f) (string-append "r" (prim-signature f)))
(prim-ρ-acc-fn f)
(ext2-∇ (prim-∇-fn f) (prim-∇-acc-fn f) m n (prim-shape-fn f)
(prim-expects-prealloc? f) (string-append "n" (prim-signature f)))
(prim-∇-acc-fn f)
(prim-shape-fn f)
#f)))

Expand Down
2 changes: 1 addition & 1 deletion lazy/autodiff/E-print.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
(require "A-autodiff.rkt")
(require "../tensors/0-lazy.rkt")
(require "../tensors/1-reflect.rkt")
(require (except-in "../../flat-tensors/ext-impl.rkt" scalarize))
(require (except-in "../../accelerated-tensors/ext-impl.rkt" scalarize))

(define max-tensor-print-length (make-parameter 5))

Expand Down
50 changes: 25 additions & 25 deletions lazy/autodiff/test/test-E-print.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -18,54 +18,54 @@
deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor
deep-tensor deep-tensor deep-tensor deep-tensor deep-tensor))

(check-equal? (make-printable long-tensor 3) (fake-tensor '(1 2 3 ...)))
(check-equal? (make-printable long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...)))
(check-equal? (make-printable deep-tensor 3)
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...)))

(check-equal? (make-printable deeper-tensor 3)
(fake-tensor
(list
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...))
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...))
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...))
'...)))
(parameterize ((max-tensor-print-length 3))
(check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1 2 3 ...)))
(check-equal? (make-printable dualized-long-tensor 3) (fake-tensor '(1.0 2.0 3.0 ...)))
(check-equal? (make-printable (list long-tensor dualized-long-tensor deeper-tensor))
(list
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor
(list
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...))
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...))
(fake-tensor
(list (fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(fake-tensor '(1 2 3 ...))
(list (fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
(fake-tensor '(1.0 2.0 3.0 ...))
'...))
'...))))))
Loading

0 comments on commit 0930df7

Please sign in to comment.