DLVM: A Modern Compiler Framework for Neural Network DSLs
Richard Wei, Lane Schwartz, Vikram Adve
University of Illinois at Urbana-Champaign


SLIDE 1-3

DLVM

A Modern Compiler Framework for Neural Network DSLs

Richard Wei, Lane Schwartz, Vikram Adve
University of Illinois at Urbana-Champaign

SLIDE 4-6

1-2 years ago

Deep Learning Frameworks / Compiler Technologies

SLIDE 7-9

Today

Deep Learning Compiler Technologies:

  • Latte.jl
  • XLA
  • NNVM + TVM
  • ONNX
  • PyTorch JIT
  • DLVM

SLIDE 10-11

Neural networks are programs

Control Flow, Auto Vectorization, Intermediate Representation, Automatic Differentiation, Static Analysis, Optimizations, Compute, Typing

SLIDE 12-17

A New Compiler Problem

  • Programs, not just a data flow graph
  • Type safety
  • Ahead-of-time AD
  • Code generation
  • Lightweight installation
SLIDE 18

Typical framework stack (C/C++, Python): DSL → Interpreter → CodeGen → Graph Optimizer → Libraries

SLIDE 19-20

Python

SLIDE 21-27

Safe language / DSL

  • NN as a host language function
  • Type safety
  • Naturalness
  • Lightweight modular staging*
  • Compiler magic

* Rompf, Tiark, and Martin Odersky. Lightweight Modular Staging: A Pragmatic Approach to Runtime Code Generation and Compiled DSLs, 2010

SLIDE 28-31

Safe language

Libraries + DSL

  • Trainer
  • Layers
  • Application API

SLIDE 32-37

Safe language

Libraries + DSL + Compiler Infrastructure

  • Generic linear algebra IR
  • Automatic differentiation
  • Optimizations
  • Code generation
  • Runtime
SLIDE 38

DSL Compiler Infrastructure

SLIDE 39-44

DLVM

  • Linear algebra IR
  • Framework for building DSLs
  • Automatic backpropagator
  • Multi-stage optimizer
  • Static code generator based on LLVM

SLIDE 45-50

DLVM architecture:

  • DLVM Core: Intermediate Representation, Analyses, Verifier, Transforms (modules CoreTensor, CoreOp, CoreCompute)
  • DSL stack: TEL (Standalone DSL), NNKit (Embedded DSL)
  • LLVM IR Generator, DLRuntime
  • Command Line Toolchain: dlc, dlopt
  • LLVM Compiler Infrastructure, targeting GPU and CPU

SLIDE 51

Core Language: DLVM IR

SLIDE 52

Tensor Type (first-class tensors)

Rank | Notation                  | Description
0    | i64                       | 64-bit integer
1    | <100 x f32>               | float vector of size 100
2    | <100 x 300 x f64>         | double matrix of size 100 x 300
n    | <100 x 300 x ... x bool>  | rank-n tensor
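The NNKit section later in the talk mirrors these ranked tensor types on the Swift side (Tensor1D<T>, Tensor2D<T>, …). A minimal sketch of what a statically ranked tensor can look like, assuming a flat storage layout; the type names echo NNKit's, but the implementation is illustrative, not DLVM code:

```swift
// Rank is encoded in the type; shape is checked at construction time.
// A toy stand-in for DLVM's <100 x f32>-style tensor types.
struct Tensor1D<T> {
    let count: Int
    var storage: [T]
    init(repeating value: T, count: Int) {
        self.count = count
        self.storage = Array(repeating: value, count: count)
    }
}

struct Tensor2D<T> {
    let rows: Int
    let cols: Int
    var storage: [T]
    init(repeating value: T, rows: Int, cols: Int) {
        self.rows = rows
        self.cols = cols
        self.storage = Array(repeating: value, count: rows * cols)
    }
    subscript(i: Int, j: Int) -> T {
        get { storage[i * cols + j] }
        set { storage[i * cols + j] = newValue }
    }
}
```

Because the rank is part of the type, passing a Tensor1D where a Tensor2D is expected is a compile-time error, which is one of the type-safety points the talk keeps returning to.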

SLIDE 53

Domain-Specific Instructions

Kind                | Example
Element-wise unary  | tanh %a: <10 x f32>
Element-wise binary | power %a: <10 x f32>, %b: 2: f32
Dot                 | dot %a: <10 x 20 x f32>, %b: <20 x 2 x f32>
Concatenate         | concatenate %a: <10 x f32>, %b: <20 x f32> along 0
Reduce              | reduce %a: <10 x 30 x f32> by add along 1
Transpose           | transpose %m: <2 x 3 x 4 x 5 x i32>
Convolution         | convolve %a: <…> kernel %b: <…> stride %c: <…> …
Slice               | slice %a: <10 x 20 x i32> from 1 upto 5
Random              | random 768 x 10 from 0.0: f32 upto 1.0: f32
Select              | select %x: <10 x f64>, %y: <10 x f64> by %flags: <10 x bool>
Compare             | greaterThan %a: <10 x 20 x bool>, %b: <1 x 20 x bool>
Data type cast      | dataTypeCast %x: <10 x i32> to f64
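To make one of these concrete: `reduce %a: <10 x 30 x f32> by add along 1` sums over the second axis, producing a <10 x f32> result. A Swift sketch of that behavior on nested arrays standing in for a rank-2 tensor (illustrative semantics, not DLVM's implementation):

```swift
// reduce ... by add along `axis`, sketched on [[Float]] as a rank-2 tensor.
func reduceAdd(_ a: [[Float]], along axis: Int) -> [Float] {
    if axis == 1 {
        // <10 x 30 x f32> -> <10 x f32>: sum each row
        return a.map { $0.reduce(0, +) }
    } else {
        // axis 0: <10 x 30 x f32> -> <30 x f32>: sum down each column
        var out = [Float](repeating: 0, count: a.first?.count ?? 0)
        for row in a {
            for (j, v) in row.enumerated() { out[j] += v }
        }
        return out
    }
}
```

The axis argument only changes which dimension collapses; the combining operator (`add` here) is a parameter of the instruction in the real IR.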

SLIDE 54

General-Purpose Instructions

Kind                       | Example
Function application       | apply %foo(%x: f32, %y: f32): (f32, f32) -> <10 x 10 x f32>
Branch                     | branch 'block_name(%a: i32, %b: i32)
Conditional (if-then-else) | conditional %cond: bool then 'then_block() else 'else_block()
Shape cast                 | shapeCast %a: <1 x 40 x f32> to 2 x 20
Extract                    | extract #x from %pt: $Point
Insert                     | insert 10: f32 to %pt: $Point at #x
Allocate stack             | allocateStack $Point count 1
Allocate heap              | allocateHeap $MNIST count 1
Deallocate                 | deallocate %x: *<10 x f32>
Load                       | load %ptr: *<10 x i32>
Store                      | store %x: <10 x i32> to %ptr: *<10 x i32>
Copy                       | copy from %src: *<10 x f16> to %dst: *<10 x f16> count 1: i64

SLIDE 55-58

Instruction Set

  • Primitive math operators & general-purpose operators
  • No softmax, sigmoid: composed from primitive math ops
  • No min, max, relu: composed from compare & select
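For example, relu falls out of compare and select exactly as the slide suggests. A Swift sketch of that lowering, where the element-wise helpers are illustrative stand-ins for DLVM's greaterThan and select instructions:

```swift
// select %x, %y by %flags: pick from x where the flag is true, else from y.
func select(_ x: [Float], _ y: [Float], by flags: [Bool]) -> [Float] {
    (0..<flags.count).map { i in flags[i] ? x[i] : y[i] }
}

// relu(x) = select(x, 0, by: x > 0) — no dedicated relu instruction needed.
func relu(_ x: [Float]) -> [Float] {
    let flags = x.map { $0 > 0 }                       // compare: greaterThan 0
    let zeros = [Float](repeating: 0, count: x.count)
    return select(x, zeros, by: flags)
}
```

Keeping the instruction set primitive like this means optimization passes only ever have to reason about a small, closed set of operations.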
SLIDE 59-66

DLVM IR

  • Full static single assignment (SSA) form
  • Control flow graph (CFG) and basic blocks with arguments
  • Custom type definitions
  • Modular architecture (module - function - basic block - instruction)
  • Textual format & in-memory format
  • Built-in parser and verifier
  • Robust unit testing via LLVM Integrated Tester (lit) and FileCheck
SLIDE 67

DLVM IR structure: a Module contains Type Definitions and Functions; each Function contains Basic Blocks; each Basic Block contains Instructions

SLIDE 68-70

module "my_module"  // Module declaration
stage raw           // Raw stage IR in the compilation phase

struct $Classifier {
    #w: <784 x 10 x f32>,
    #b: <1 x 10 x f32>,
}
type $MyClassifier = $Classifier

func @inference: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32> {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    conditional true: bool then 'b0() else 'b1()
'b0():
    return %0.1: <1 x 10 x f32>
'b1():
    return 0: <1 x 10 x f32>
}

SLIDE 71

Transformations: Differentiation & Optimizations

SLIDE 72-74

func @inference: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32> {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    return %0.1: <1 x 10 x f32>
}

[gradient @inference wrt 1, 2]
func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>)

Differentiation Pass: canonicalizes every gradient function declaration in an IR module

SLIDE 75-78

The differentiation pass first copies instructions from the original function, then generates adjoint code:

func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>) {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
}

SLIDE 79-80

Generated adjoint code, before optimization:

func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>) {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    %0.2 = transpose %x: <1 x 784 x f32>
    %0.3 = multiply %0.2: <1 x 784 x f32>, 1: f32
    return (%0.3: <1 x 10 x f32>, 1: f32): (<1 x 10 x f32>, f32)
}

The Algebra Simplification Pass folds the multiplication by 1:

SLIDE 81-82

func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>) {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    %0.2 = transpose %x: <1 x 784 x f32>
    return (%0.2: <1 x 10 x f32>, 1: f32): (<1 x 10 x f32>, f32)
}

The Dead Code Elimination Pass removes the now-unused %0.0 and %0.1:

SLIDE 83

func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>) {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = transpose %x: <1 x 784 x f32>
    return (%0.0: <1 x 10 x f32>, 1: f32): (<1 x 10 x f32>, f32)
}

SLIDE 84-89

Configurable gradient declaration:

func @inference: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32> {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    return %0.1: <1 x 10 x f32>
}

[gradient @inference from 0]
func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)

[gradient @inference from 0 wrt 1, 2]
func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>)

[gradient @inference from 0 wrt 1, 2 keeping 0]
func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>, <1 x 10 x f32>)

[gradient @inference from 0 wrt 1, 2 keeping 0 seedable]
func @inference_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>, <1 x 10 x f32>)

from: selects which output to differentiate in a tuple return
wrt: differentiate with respect to arguments 1 & 2
keeping: keep the original output
seedable: allow passing in back-propagated gradients as a seed

SLIDE 90-94

func @f: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32> {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    return %0.1: <1 x 10 x f32>
}

func @g: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32> {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = apply @f(%x, %w, %b): (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32>
    %0.1 = tanh %0.0: <1 x 10 x f32>
    return %0.1: <1 x 10 x f32>
}

[gradient @g wrt 1, 2]
func @g_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> (<784 x 10 x f32>, <1 x 10 x f32>)

[gradient @f wrt 1, 2 seedable]
func @f_grad: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>, <1 x 10 x f32>)
    -> (<784 x 10 x f32>, <1 x 10 x f32>)

Seed: the extra <1 x 10 x f32> argument to @f_grad lets @g_grad pass its back-propagated gradient through the call to @f.
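The seed is just the incoming gradient in reverse-mode AD's chain rule. A scalar Swift sketch of how a g_grad would drive a seedable f_grad; the function names and the scalar simplification of dot/add are mine, not DLVM code:

```swift
import Foundation

// f(x, w, b) = x * w + b — a scalar stand-in for dot followed by add.
func f(_ x: Double, _ w: Double, _ b: Double) -> Double { x * w + b }

// Seedable gradient of f wrt (w, b): scales df/dw = x and df/db = 1 by the seed.
func fGrad(_ x: Double, _ w: Double, _ b: Double, seed: Double) -> (dw: Double, db: Double) {
    (dw: seed * x, db: seed)
}

// g = tanh(f(x, w, b)); its gradient back-propagates
// d tanh(y)/dy = 1 - tanh(y)^2 into fGrad as the seed.
func gGrad(_ x: Double, _ w: Double, _ b: Double) -> (dw: Double, db: Double) {
    let y = f(x, w, b)
    let seed = 1 - tanh(y) * tanh(y)
    return fGrad(x, w, b, seed: seed)
}
```

Without the seedable variant, @g_grad would have to re-derive @f's adjoint inline; with it, each function's gradient composes by passing gradients through the call boundary.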

SLIDE 95-109

Compilation Phases

DLVM pipeline: Differentiation → Optimizations → Compute Generation → Compute Scheduling → LLGen → LLVM

Analyses & Verification throughout: Dominance, Side Effects, Type Checking, Differentiability

Optimizations: Linear Algebra Fusion, AD Checkpointing, Algebra Simplification, Constant Propagation, Dead Code Elimination, Common Subexpression Elimination, Matrix Chaining

IR stages: stage raw → stage optimizable → stage compute → stage schedule

SLIDE 110-114

DLVM: DSL + Compiler Infrastructure

  • NN as program, not a graph
  • Static analysis
  • Type safety
  • Naturalness
  • Lightweight modular staging
  • Compiler magic
SLIDE 115

NNKit: Staged DSL in Swift

SLIDE 116-120

NNKit

  • It’s a prototype!
  • Tensor computation embedded in host language
  • Type safety
  • Generates DLVM IR on the fly
SLIDE 121-124

Language

  • Statically ranked tensors
    T, Tensor1D<T>, Tensor2D<T>, Tensor3D<T>, Tensor4D<T>
  • Type wrapper for staging – Rep<Wrapped>
    Rep<Float>, Rep<Tensor1D<Float>>, Rep<Tensor2D<T>>
  • Operator overloading
    func + <T: Numeric>(_: Rep<T>, _: Rep<T>) -> Rep<T>
    func • (_: Rep<Tensor2D<T>>, _: Rep<Tensor2D<T>>) -> Rep<Tensor2D<T>>
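The overloaded operators do not compute anything; they record the operation so it can be lowered later. A self-contained Swift sketch of the idea, where Expr and this cut-down Rep are illustrative simplifications rather than NNKit's actual types:

```swift
// Staging via operator overloading: `+` and `•` build an expression tree
// instead of executing the arithmetic.
indirect enum Expr: Equatable {
    case constant(Float)
    case add(Expr, Expr)
    case dot(Expr, Expr)
}

// Phantom-typed wrapper: T tracks the staged value's type, expr records it.
struct Rep<T> {
    let expr: Expr
}

func + (lhs: Rep<Float>, rhs: Rep<Float>) -> Rep<Float> {
    Rep(expr: .add(lhs.expr, rhs.expr))
}

infix operator •: MultiplicationPrecedence
func • (lhs: Rep<Float>, rhs: Rep<Float>) -> Rep<Float> {
    Rep(expr: .dot(lhs.expr, rhs.expr))
}
```

With this in place, `x • w + b` written in ordinary Swift yields a tree `add(dot(x, w), b)`, which is exactly the shape a code generator needs to emit the dot and add IR instructions seen earlier.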
SLIDE 125-127

Language

  • Lambda abstraction
    func lambda<T, U>(_ f: (Rep<T>) -> Rep<U>) -> Rep<(T) -> U>
  • Function application
    subscript<T, U>(arg: Rep<T>) -> Rep<U> where Wrapped == (T) -> U
    subscript<T, U>(arg: T) -> U where Wrapped == (T) -> U   // JIT DLVM IR

slide-139
SLIDE 139

Staged Evaluation

Expression Staging Shape Specialization DLVM IR Generation Function Reification

Rep<(Float2D) -> Float2D>  ==>  (Float2D) -> Float2D
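The four stages above can be caricatured as plain function composition: a staged expression is specialized to concrete shapes, lowered to DLVM IR, and finally reified as an ordinary Swift function. Every name and type below is an illustrative placeholder, not NNKit's API, and the "JIT" step here simply echoes its input.

```swift
// Illustrative pipeline sketch (all names hypothetical, not NNKit's API).

struct Staged { let expr: String }                          // 1. Expression Staging
struct Specialized { let expr: String; let shape: [Int] }   // 2. Shape Specialization
struct Module { let ir: String }                            // 3. DLVM IR Generation

func specialize(_ s: Staged, to shape: [Int]) -> Specialized {
    Specialized(expr: s.expr, shape: shape)
}

func generateIR(_ s: Specialized) -> Module {
    // Stand-in for real IR generation: record the specialized shape.
    Module(ir: "func @f: <\(s.shape.map(String.init).joined(separator: " x ")) x f32> ...")
}

// 4. Function Reification: JIT-compile the IR and wrap it as a closure.
// Stand-in for JIT compilation: the returned closure just echoes its input.
func reify(_ module: Module) -> ([Float]) -> [Float] {
    return { input in input }
}

let staged = Staged(expr: "x • w + b")
let fn = reify(generateIR(specialize(staged, to: [1, 10])))
// `fn` is now an ordinary ([Float]) -> [Float] Swift function.
```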

slide-146
SLIDE 146

typealias Float2D = Tensor2D<Float>

struct Parameters {
  var w: Float2D
  var b: Float2D
}

let f: Rep<(Float2D, Float2D, Float2D) -> Float2D> =
  lambda { x, w, b in x • w + b }

let params: Parameters = …
let x: Float2D = [[0.0, 1.0]]

f[x, params.w, params.b] // ==> result

slide-148
SLIDE 148

let f: Rep<(Float2D, Float2D, Float2D) -> Float2D> =
  lambda { x, w, b in x • w + b }

f[x, w, b] // x: 1x784, w: 784x10, b: 1x10

func @f: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>) -> <1 x 10 x f32> {
'entry(%x: <1 x 784 x f32>, %w: <784 x 10 x f32>, %b: <1 x 10 x f32>):
    %0.0 = dot %x: <1 x 784 x f32>, %w: <784 x 10 x f32>
    %0.1 = add %0.0: <1 x 10 x f32>, %b: <1 x 10 x f32>
    return %0.1: <1 x 10 x f32>
}
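Shape specialization in this example amounts to propagating the concrete shapes (1x784, 784x10, 1x10) through `dot` and `add` to derive the 1x10 result type in the IR. A minimal sketch, with hypothetical helper names:

```swift
// Hypothetical shape-propagation sketch for the staged `x • w + b`.

typealias Shape = (rows: Int, cols: Int)

func dotShape(_ a: Shape, _ b: Shape) -> Shape {
    precondition(a.cols == b.rows, "inner dimensions must match")
    return (a.rows, b.cols)
}

func addShape(_ a: Shape, _ b: Shape) -> Shape {
    precondition(a == b, "elementwise add requires equal shapes")
    return a
}

let x: Shape = (1, 784), w: Shape = (784, 10), b: Shape = (1, 10)
let result = addShape(dotShape(x, w), b)
assert(result == (1, 10)) // matches the <1 x 10 x f32> return type in the IR
```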

slide-154
SLIDE 154

let f: Rep<(Float2D, Float2D, Float2D) -> Float2D> =
  lambda { x, w, b in x • w + b }

let g = lambda { x, w, b in
  let linear = f[x, w, b]
  return tanh(linear)
}

let ∇g = gradient(of: g, withRespectTo: (1, 2))
// ∇g : Rep<(Float2D, Float2D, Float2D) -> (Float2D, Float2D)>

∇g[x, w, b] // ==> ( ∂g/∂w, ∂g/∂b )

[gradient @f wrt 1, 2]
func @g: (<1 x 784 x f32>, <784 x 10 x f32>, <1 x 10 x f32>)
      -> (<784 x 10 x f32>, <1 x 10 x f32>)
slide-157
SLIDE 157

let f: Rep<(Float2D, Float2D, Float2D) -> Float2D> =
  lambda { x, w, b in x • w + b }

let g = lambda { x, w, b in
  let linear = f[x, w, b]
  return tanh(linear)
}

let ∇g = gradient(of: g, withRespectTo: (1, 2), seedable: true)
// ∇g : Rep<(Float2D, Float2D, Float2D, Float2D) -> (Float2D, Float2D)>

slide-159
SLIDE 159

let f: Rep<(Float2D, Float2D, Float2D) -> Float2D> =
  lambda { x, w, b in x • w + b }

let g = lambda { x, w, b in
  let linear = f[x, w, b]
  return tanh(linear)
}

let ∇g = gradient(of: g, withRespectTo: (1, 2), seedable: true, keeping: (0))
// ∇g : Rep<(Float2D, Float2D, Float2D, Float2D) -> (Float2D, Float2D, Float2D)>

∇g[x, w, b, ∂h_∂g] // ==> ( ∂h/∂w, ∂h/∂b, g(x,w,b) )
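What `gradient(of:withRespectTo:seedable:)` computes can be checked by hand in the scalar case g(x, w, b) = tanh(x·w + b). The sketch below is an illustrative hand-derived reverse-mode gradient, not NNKit code; the `seed` parameter mirrors the role of `seedable:` (an upstream ∂h/∂g multiplied in by the chain rule), and the result is validated against finite differences.

```swift
import Foundation

// Scalar stand-in for g(x, w, b) = tanh(x • w + b).
func g(_ x: Double, _ w: Double, _ b: Double) -> Double { tanh(x * w + b) }

// Reverse-mode by hand: d tanh(l)/dl = 1 - tanh(l)^2, with l = x*w + b.
// `seed` plays the role of the seedable upstream gradient ∂h/∂g.
func gradG(_ x: Double, _ w: Double, _ b: Double, seed: Double = 1.0)
    -> (dw: Double, db: Double) {
    let dTanh = (1 - pow(tanh(x * w + b), 2)) * seed
    return (dw: dTanh * x, db: dTanh)
}

let (x, w, b) = (0.5, -1.2, 0.3)
let grad = gradG(x, w, b)

// Central finite differences as a sanity check.
let eps = 1e-6
let dwFD = (g(x, w + eps, b) - g(x, w - eps, b)) / (2 * eps)
let dbFD = (g(x, w, b + eps) - g(x, w, b - eps)) / (2 * eps)
print(abs(grad.dw - dwFD) < 1e-6, abs(grad.db - dbFD) < 1e-6) // true true
```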

slide-160
SLIDE 160

Safe language

Libraries DSL Compiler Infrastructure

slide-162
SLIDE 162

Swift

Libraries NNKit DLVM

DLVM is written in Swift!

slide-167
SLIDE 167

PL & Compilers + ML

  • Programs, not just a data flow graph
  • Type safety
  • Ahead-of-time AD
  • Code generation
slide-168
SLIDE 168

dlvm.org

slide-169
SLIDE 169

DLVM