File scheme/cps-conversion.scm from the latest check-in


;; CPS変換を丁寧に考えてみる
;; できるだけ余計な難しさがない最小限のサンプルを書く。
;; sectionの前後に(exit)などを挟んで評価結果を見ていくとよい。

;; CPS変換にGaucheのmatchやgensymを使用。
;; 結果を見やすくするためにr7rsのshowを使う
(use util.match)
(use scheme.show)



;; 結果を見やすくする用
(define (section label)
  (format #t "===== ~A =====~%" label))



(section 'ret)

;; 普通の関数は呼ばれたところに値を返す

(define (inc-normal n) (+ n 1))

(print (inc-normal 0)) ;; => 1

;; その返す先を与えてやる、というのがCPSなので、継続というよりリターンとまず呼んでみる。
;; CPSでの関数は第一引数で必ずリターン先(継続)を受け取るとする。

(define (inc-cps ret n) (ret (+ n 1)))

(inc-cps print 0) ;; => 1



(section 'nest)

;; 入れ子になっている関数呼び出しの場合は、ちょっと工夫がいる。
;; わかりやすくするために別の関数を定義しておく。

(define (dec-normal n) (- n 1))
(define (dec-cps ret n) (ret (- n 1)))

;; まず普通の関数では単純に入れ子にすればいい。
;; 0に1足して1引いた結果を出力

(print (dec-normal (inc-normal 0))) ;; => 0

;; (inc-cps ret 0)のretには、返り値の1が「その後どうなって欲しいか」を書く。
;; 1が次はdec-cpsされて欲しいので、inc-cpsに渡すリターン先は(lambda (n) (dec-cps ?? n))となる。
;; 更にその??には、外側のretが入る。だから上の式は次のようになる。

(inc-cps (lambda (n) (dec-cps print n)) 0) ;; => 0

;; 複数の引数がある場合は、同様にlambdaを重ねていく。



(section 'if)

;; if式の場合を考える。
;; まず普通の関数内で使う場合

(define (check42-normal n)
  (if (= n 42) 'yes 'no))

(print (check42-normal 42)) ;; => yes
(print (check42-normal 43)) ;; => no

;; CPS変換中の場合、全ての関数(手続き)がretを受け取るので、
;; (if test then else)のtestのリターン先としてif以降を渡す。

;; (= a b)も(cps= ret a b)としてリターン先を受け取るように定義しておく。

(define (cps= ret a b) (ret (= a b)))

;; cps=の真偽値が、(if test then else)のtestに来るようにしたい。
;; (cps= (lambda (真偽値) (if 真偽値 (ret then) (ret else))) a b)となれば良い。

(define (check42-cps ret n)
  (cps= (lambda (test) (if test (ret 'yes) (ret 'no)))
        n
        42))

(check42-cps print 42) ;; => yes
(check42-cps print 43) ;; => no



(section 'set!)

;; set!式
;; (set! var-name (expr))について、(expr)の結果がset!の値としてリターンして欲しい。
;; つまりexprには(lambda (result) (set! var-name result))というリターンが渡される。
;; Schemeの仕様ではset!の結果は不定だけど、ここではresultが返ることにする。

;; 呼ばれる度にカウントアップする関数を定義してみる。

(define count-up-normal
  (let ([count 0])
    (lambda ()
      (set! count (inc-normal count))
      count)))

(print (count-up-normal)) ;; => 1
(print (count-up-normal)) ;; => 2

(define count-up-cps
  (let ([count 0])
    (lambda (ret)
      (inc-cps (lambda (n)
                 (set! count n)
                 (ret n))
               count))))

(count-up-cps print) ;; => 1
(count-up-cps print) ;; => 2



(section 'begin)

;; begin式
;; 最後の式以外の結果は捨てられて良い。

(define (foo-normal)
  (print 'foo)
  'done-foo)

(define (bar-normal)
  (print 'bar)
  'done-bar)

(define (foo-bar-normal)
  (foo-normal)
  (bar-normal))

(print (foo-bar-normal)) ;; => foo foo done-barとprint

;; CPSにおいては、fooには「結果を受け取ってそれを無視して次」というリターン先を渡す。
;; つまり(lambda (ignore) (bar-cps ret)))となる。

(define (foo-cps ret)
  (print 'foo)
  (ret 'done-foo))

(define (bar-cps ret)
  (print 'bar)
  (ret 'done-bar))

(define (foo-bar-cps ret)
  (foo-cps (lambda (_) (bar-cps ret))))

(foo-bar-cps print)



(section 'lambda)

;; lambda式の内部の変換は、既に出ている。
;; 暗黙のbeginが無いと仮定して、(lambda (a b ...) (expr a b ...)) は
;; (lambda (ret a b ...) (expr ret a b ...)) と変換する。
;; つまり
;;     1. 第一引数にリターン先を追加する
;;     2. そのリターン先で(lambda args expr)のexprをCPS変換する
;; これは最初の例のdefineで既にやっている。
;; 復習としてもう一度lambda式として書いてみる。
;; ついでにinc-normalとinc-cpsを使って、lambdaのexprがCPS変換済みの場合を見てみる。

(define inc-normal-again
  (lambda (n)
    (inc-normal n)))

(print (inc-normal-again 0)) ;; => 1

(define inc-cps-again
  (lambda (ret n)
    (inc-cps ret n)))

(inc-cps-again print 0) ;; => 1

;; 外部の変換、つまりlambda式をどう扱うかについては、単純に現在のリターン先へ
;; lambdaを渡してやればいい。

;;TODO サンプル



(section 'call/cc)

;; (call/cc f)という式について
;; まず(f ret k)のretは何か?
;; fに渡されたリターン先kが使われなかった場合、その後の計算全てを無視して
;; どこかに返る必要がある。
;;
;; これはどこなのかcall/ccからはわからないんだけど、少なくとも
;; (lambda (x) x)をretとして渡せばどこかには返る。
;; ちなみに下のcps-conversionではevalに返る。
;;
;; 次にkについて。これは現在の継続なのでcall/cc時のretではあるんだけど、
;; CPS変換中の呼び出しは(k ret args...)となるので、(lambda args (apply k args))としてやる。

;;TODO retを無視したり保存するサンプル



(section 'cps-conversion)

;; 実際にCPS変換器を作ってみる。

;; とりあえずpair以外はatomとする。
(define (atom? x) (not (pair? x)))

(define (cps-conversion expr ret)
  ;; リターン先生成時の値用シンボル生成手続き。
  (define genval
    (let ([id 0])
      (lambda ()
        (set! id (+ id 1))
        (string->symbol (format #f "$v~S" id)))))

  ;; cps-conversionは再帰呼出しの時長いのでnamed letを使う
  (let rec ([expr expr] [ret ret])
    (match expr
      ;; atomの場合は、リターン先に値を返すだけ
      [(? atom?) `(,ret ,expr)]

      ;; if
      [('if test then else)
       (let ([r2 (genval)]) ;; test式の結果
         (rec test
              `(lambda (,r2)
                 ;; test式の結果によって、ifに渡されたretでCPS変換済みのthen/elseに任せる
                 (if ,r2 ,(rec then ret) ,(rec else ret)))))]

      ;; begin
      ;; 最後の式だけ特別扱いする必要があり、また後ろのリターン先(継続)から
      ;; 徐々に作っていく必要があるので、まずbodyをreverseしてから変換する。
      [('begin . body)
       (let* ([body (reverse body)]
              [last (rec (car body) ret)] ;; 最後の式はbeginへのリターン先を使う
              [rest (cdr body)])
         (fold (lambda (x ret)
                 (let ([r2 (genval)])
                   ;; beginのlast以外については、結果を捨てる
                   (rec x `(lambda (,r2) ,ret))))
               last rest))]

      ;; set!
      ;; retにvalを返す仕様にする
      [('set! var val)
       (let ([v (genval)])
         (rec val `(lambda (,v) (set! ,var ,v) (,ret ,v) )))]

      ;; lambda
      ;; argsの先頭にリターン先受け取りを追加して、そのリターン先をもとに内部をCPS変換。
      ;; 最後に現在のリターン先にlambda式を渡す。
      [('lambda args expr)
       (let* ([v (genval)] ;; 内部のリターン先
              [args (cons v args)])
         `(,ret (lambda ,args ,(rec expr v))))]

      ;; quote
      ;; そのままretにクォート式を渡す。
      [('quote x) `(,ret ',x)]

      ;; call/cc
      [('call/cc f)
       (let* (;; (f ret k)に渡されるkは、CPS内部で呼ばれるので(k ret . args)のlambdaに定義
              ;; 継続kが呼ばれたときのretに、call/cc時のretの結果が入ることに注意。
              ;; これはちょっと難しい。保存した継続が後から呼ばれた時、その継続を
              ;; 呼んだ元に返らないといけない。そうじゃないと、保存した継続自体が
              ;; その後の継続を断ち切ってしまう。
              [kret (genval)]
              [k   `(lambda (,kret . args) (,kret (apply ,ret args)))]
              ;; (f ret k)のretは、CPS変換されたコードをevalしている評価器自体に
              ;; どこで継続を切るか任せる。
              [ret `(lambda (x) x)]
              [vf  (genval)])
         (rec f `(lambda (,vf) (,vf ,ret ,k))))]

      ;; primitives
      ;; 変換後のコードをevalしたいので、いくつかの手続きはCPS対応しておく
      [('= . args) (rec `(cps= ,@args) ret)]
      [('> . args) (rec `(cps> ,@args) ret)]
      [('+ . args) (rec `(cps+ ,@args) ret)]
      [('prn . args) (rec `(cps-prn ,@args) ret)]

      ;; call
      ;; 引数をv1, v2...と集めていって、最後に関数fの結果vfをもとに(vf ret v1 v2 ...)という呼び出しにする。
      ;; 最後の関数適用で各引数に対応する変数がわかってないといけないので、まず最初にmapで作る。
      ;; ((lambda (v1) ((lambda (vf) (vf ret v1)) ef)) e1)
      [(f . args)
       (let* ([vars (map (lambda (_) (genval)) args)]
              [vf   (genval)]
              [last (rec f `(lambda (,vf) (,vf ,ret ,@vars)))])
         (fold (lambda (x var ret)
                 (rec x `(lambda (,var) ,ret)))
               last
               args
               vars))]
      )))

;; eval用関数定義
;; -----

(define (cps> ret a b) (ret (> a b)))
(define (cps+ ret a b) (ret (+ a b)))

;; 最後にScheme側に値を返すリターン先(継続)
(define (identity x) x)

;; 内部でのprint
(define (cps-prn ret . args)
  (for-each print args)
  (ret (car args)))

;; 結果を見やすく
;; -----

(define (cpsconv expr)
  (print "## SOURCE")
  (show #t (pretty expr))
  (print "## CPS")
  (let* ([converted (cps-conversion expr 'identity)])
    (show #t (pretty converted))
    (print "## EVAL")
    (let* ([val (eval converted (current-module))])
      (show #t (pretty val)))))

;; テスト

(cpsconv '(if (= 42 43) 'what? 'ok))
(cpsconv '((lambda (x) x) 42))
(cpsconv '(begin (prn 'foo) (prn 'bar) 'done))



(section 'cps-call/cc)

;; call/ccの継続にそのまま返す、つまり何もしない
(cpsconv '(+ (call/cc (lambda (k) (k 41))) 1))

;; 残りの継続(この場合は43の評価)を無視
(cpsconv '(begin (call/cc (lambda (k) 42)) 43))

;; 継続を保存して後から起動。
(define *k*)
(cpsconv '(begin
            (call/cc (lambda (k) (begin (set! *k* k) 42)))
            43))
(print (*k* identity '())) ;; call/ccがあった場所に入るべき適当な値を渡してやる

;; 何度も起動してカウントアップしてみる
(define *count* 0)
(cpsconv
 '(begin
    (set! *count* (+ *count*
                     (call/cc (lambda (k) (begin (set! *k* k) *count*)))))
    *count*))
;; => call/ccが*count*を返すので最初は0
(print (*k* identity 1)) ;; => 0+1 = 1
(print (*k* identity 2)) ;; => 1+2 = 3
(print (*k* identity 3)) ;; => 3+3 = 6