Added Wasm example

This commit is contained in:
David Beazley 2019-02-17 19:48:18 -06:00
parent 5a7f8ab652
commit a728a23adf
3 changed files with 1199 additions and 0 deletions

245
example/wasm/expr.py Normal file
View File

@ -0,0 +1,245 @@
# -----------------------------------------------------------------------------
# expr.py
#
# Proof-of-concept encoding of functions/expressions into Wasm.
#
# This file implements a mini-language for writing Wasm functions as expressions.
# It only supports integers.
#
# Here's a few examples:
#
# # Some basic function definitions
# add(x, y) = x + y;
# mul(x, y) = x * y;
# dsquare(x, y) = mul(x, x) + mul(y, y);
#
# # A recursive function
# fact(n) = if n < 1 then 1 else n*fact(n-1);
#
# The full grammar:
#
# functions : functions function
# | function
#
# function : NAME ( parms ) = expr ;
#
# expr : expr + expr
# | expr - expr
# | expr * expr
# | expr / expr
# | expr < expr
# | expr <= expr
# | expr > expr
# | expr >= expr
# | expr == expr
# | expr != expr
# | ( expr )
# | NAME (exprs)
# | if expr then expr else expr
# | NUMBER
#
# Note: This is implemented as one-pass compiler with no intermediate AST.
# Some of the grammar rules have to be written in a funny way to make this
# work. If doing this for real, I'd probably build an AST and construct
# Wasm code through AST walking.
# -----------------------------------------------------------------------------
import sys
sys.path.append('../..')
from sly import Lexer, Parser
import wasm
class ExprLexer(Lexer):
tokens = { NAME, NUMBER, PLUS, TIMES, MINUS, DIVIDE, LPAREN, RPAREN, COMMA,
LT, LE, GT, GE, EQ, NE, IF, THEN, ELSE, ASSIGN, SEMI }
ignore = ' \t'
# Tokens
NAME = r'[a-zA-Z_][a-zA-Z0-9_]*'
NAME['if'] = IF
NAME['then'] = THEN
NAME['else'] = ELSE
NUMBER = r'\d+'
# Special symbols
PLUS = r'\+'
MINUS = r'-'
TIMES = r'\*'
DIVIDE = r'/'
LPAREN = r'\('
RPAREN = r'\)'
COMMA = r','
LE = r'<='
LT = r'<'
GE = r'>='
GT = r'>'
EQ = r'=='
NE = r'!='
ASSIGN = r'='
SEMI = ';'
# Ignored pattern
ignore_newline = r'\n+'
ignore_comment = r'#.*\n'
# Extra action for newlines
def ignore_newline(self, t):
self.lineno += t.value.count('\n')
def error(self, t):
print("Illegal character '%s'" % t.value[0])
self.index += 1
class ExprParser(Parser):
tokens = ExprLexer.tokens
precedence = (
('left', IF, ELSE),
('left', EQ, NE, LT, LE, GT, GE),
('left', PLUS, MINUS),
('left', TIMES, DIVIDE),
('right', UMINUS)
)
def __init__(self):
self.functions = { }
self.module = wasm.Module()
@_('functions function')
def functions(self, p):
pass
@_('function')
def functions(self, p):
pass
@_('function_decl ASSIGN expr SEMI')
def function(self, p):
self.function.block_end()
self.function = None
@_('NAME LPAREN parms RPAREN')
def function_decl(self, p):
self.locals = { name:n for n, name in enumerate(p.parms) }
self.function = self.module.add_function(p.NAME, [wasm.i32]*len(p.parms), [wasm.i32])
self.functions[p.NAME] = self.function
@_('NAME LPAREN RPAREN')
def function_decl(self, p):
self.locals = { }
self.function = self.module.add_function(p.NAME, [], [wasm.i32])
self.functions[p.NAME] = self.function
@_('parms COMMA parm')
def parms(self, p):
return p.parms + [p.parm]
@_('parm')
def parms(self, p):
return [ p.parm ]
@_('NAME')
def parm(self, p):
return p.NAME
@_('expr PLUS expr')
def expr(self, p):
self.function.i32.add()
@_('expr MINUS expr')
def expr(self, p):
self.function.i32.sub()
@_('expr TIMES expr')
def expr(self, p):
self.function.i32.mul()
@_('expr DIVIDE expr')
def expr(self, p):
self.function.i32.div_s()
@_('expr LT expr')
def expr(self, p):
self.function.i32.lt_s()
@_('expr LE expr')
def expr(self, p):
self.function.i32.le_s()
@_('expr GT expr')
def expr(self, p):
self.function.i32.gt_s()
@_('expr GE expr')
def expr(self, p):
self.function.i32.ge_s()
@_('expr EQ expr')
def expr(self, p):
self.function.i32.eq()
@_('expr NE expr')
def expr(self, p):
self.function.i32.ne()
@_('MINUS expr %prec UMINUS')
def expr(self, p):
pass
@_('LPAREN expr RPAREN')
def expr(self, p):
pass
@_('NUMBER')
def expr(self, p):
self.function.i32.const(int(p.NUMBER))
@_('NAME')
def expr(self, p):
self.function.local.get(self.locals[p.NAME])
@_('NAME LPAREN exprlist RPAREN')
def expr(self, p):
self.function.call(self.functions[p.NAME])
@_('NAME LPAREN RPAREN')
def expr(self, p):
self.function.call(self.functions[p.NAME])
@_('IF expr thenexpr ELSE expr')
def expr(self, p):
self.function.block_end()
@_('exprlist COMMA expr')
def exprlist(self, p):
pass
@_('expr')
def exprlist(self, p):
pass
@_('startthen expr')
def thenexpr(self, p):
self.function.else_start()
@_('THEN')
def startthen(self, p):
self.function.if_start(wasm.i32)
if __name__ == '__main__':
import sys
if len(sys.argv) != 2:
raise SystemExit(f'Usage: {sys.argv[0]} module')
lexer = ExprLexer()
parser = ExprParser()
parser.parse(lexer.tokenize(open(sys.argv[1]).read()))
name = sys.argv[1].split('.')[0]
parser.module.write_wasm(name)
parser.module.write_html(name)
print(f'Wrote: {name}.wasm')
print(f'Wrote: {name}.html')
print('Use python3 -m http.server to test')

25
example/wasm/test.e Normal file
View File

@ -0,0 +1,25 @@
# Experimental Wasm function examples.
# To run:
#
# 1. First run python3 expr.py test.e
# 2. Use python3 -m http.server
#
# Go to a browser and visit http://localhost:8000/test.html.
# From the browser, open the Javascript console. Try executing
# the functions from there.
#
# Some basic functions
add(x,y) = x+y;
sub(x,y) = x-y;
mul(x,y) = x*y;
div(x,y) = x/y;
# A function calling other functions
dsquare(x,y) = mul(x,x) + mul(y,y);
# A conditional
minval(a, b) = if a < b then a else b;
# Some recursive functions
fact(n) = if n <= 1 then 1 else n*fact(n-1);
fib(n) = if n < 2 then 1 else fib(n-1) + fib(n-2);

929
example/wasm/wasm.py Normal file
View File

@ -0,0 +1,929 @@
# wasm.py
#
# Experimental builder for Wasm binary encoding. Use at your own peril.
#
# Author: David Beazley (@dabeaz)
# Copyright (C) 2019
# http://www.dabeaz.com
import struct
import enum
from collections import defaultdict
import json
# ------------------------------------------------------------
# Low level encoding of values
def encode_unsigned(value):
'''
Produce an LEB128 encoded unsigned integer.
'''
bits = bin(value)[2:]
if len(bits) % 7:
bits = '0'*(7 - len(bits) % 7) + bits
parts = [ bits[i:i+7] for i in range(0,len(bits), 7) ]
parts = [ parts[0], *['1'+p for p in parts[1:]] ]
parts = [ int(p, 2) for p in parts ]
return bytes(parts[::-1])
assert encode_unsigned(624485) == bytes([0xe5, 0x8e, 0x26])
def encode_signed(value):
'''
Produce a LEB128 encoded signed integer.
'''
if value > 0:
return encode_unsigned(value)
bits = bin(~(~0 << value.bit_length()) & value)[2:]
bits = '1'+ '0'*(value.bit_length() - len(bits)) + bits
if len(bits) % 7:
bits = '1'*(7 - len(bits) % 7) + bits
return encode_unsigned(int(bits,2))
assert encode_signed(-624485) == bytes([0x9b, 0xf1, 0x59])
def encode_float64(value):
return struct.pack('<d', value)
def encode_float32(value):
return struct.pack('<f', value)
def encode_name(value):
data = value.encode('utf-8')
return encode_vector(list(data[i:i+1] for i in range(len(data))))
def encode_vector(items):
'''
Items is a list of encoded values
'''
if isinstance(items, bytes):
return encode_unsigned(len(items)) + items
else:
return encode_unsigned(len(items)) + b''.join(items)
# ------------------------------------------------------------
# Instruction encoding enums.
#
# Wasm defines 4 core data types [i32, i64, f32, f64]. These type
# names are used in various places (specifying functions, globals,
# etc.). However, the type names are also used as a namespace for
# type-specific instructions such as i32.add. We're going to use
# Python enums to set up this arrangement in a clever way that
# makes it possible to do both of these tasks.
# Metaclass for instruction encoding categories. The class itself
# can be used as an integer when encoding instructions.
class HexEnumMeta(enum.EnumMeta):
def __int__(cls):
return int(cls._encoding)
__index__ = __int__
def __repr__(cls):
return cls.__name__
@classmethod
def __prepare__(meta, name, bases, encoding=0):
return super().__prepare__(name, bases)
@staticmethod
def __new__(meta, clsname, bases, methods, encoding=0):
cls = super().__new__(meta, clsname, bases, methods)
cls._encoding = encoding
return cls
class HexEnum(enum.IntEnum):
def __repr__(self):
return f'<{self!s}: 0x{self:x}>'
HexEnum.__class__ = HexEnumMeta
class i32(HexEnum, encoding=0x7f):
eqz = 0x45
eq = 0x46
ne = 0x47
lt_s = 0x48
lt_u = 0x49
gt_s = 0x4a
gt_u = 0x4b
le_s = 0x4c
le_u = 0x4d
ge_s = 0x4e
ge_u = 0x4f
clz = 0x67
ctz = 0x68
popcnt = 0x69
add = 0x6a
sub = 0x6b
mul = 0x6c
div_s = 0x6d
div_u = 0x6e
rem_s = 0x6f
rem_u = 0x70
and_ = 0x71
or_ = 0x72
xor = 0x73
shl = 0x74
shr_s = 0x75
shr_u = 0x76
rotl = 0x77
rotr = 0x78
wrap_i64 = 0xa7
trunc_f32_s = 0xa8
trunc_f32_u = 0xa9
trunc_f64_s = 0xaa
trunc_f64_u = 0xab
reinterpret_f32 = 0xbc
load = 0x28
load8_s = 0x2c
load8_u = 0x2d
load16_s = 0x2e
load16_u = 0x2f
store = 0x36
store8 = 0x3a
store16 = 0x3b
const = 0x41
class i64(HexEnum, encoding=0x7e):
eqz = 0x50
eq = 0x51
ne = 0x52
lt_s = 0x53
lt_u = 0x54
gt_s = 0x55
gt_u = 0x56
le_s = 0x57
le_u = 0x58
ge_s = 0x59
ge_u = 0x5a
clz = 0x79
ctz = 0x7a
popcnt = 0x7b
add = 0x7c
sub = 0x7d
mul = 0x7e
div_s = 0x7f
div_u = 0x80
rem_s = 0x81
rem_u = 0x82
and_ = 0x83
or_ = 0x84
xor = 0x85
shl = 0x86
shr_s = 0x87
shr_u = 0x88
rotl = 0x89
rotr = 0x8a
extend_i32_s = 0xac
extend_i32_u = 0xad
trunc_f32_s = 0xae
trunc_f32_u = 0xaf
trunc_f64_s = 0xb0
trunc_f64_u = 0xb1
reinterpret_f64 = 0xbd
load = 0x29
load8_s = 0x30
load8_u = 0x31
load16_s = 0x32
load16_u = 0x33
load32_s = 0x34
load32_u = 0x35
store = 0x37
store8 = 0x3c
store16 = 0x3d
store32 = 0x3e
const = 0x42
class f32(HexEnum, encoding=0x7d):
eq = 0x5b
ne = 0x5c
lt = 0x5d
gt = 0x5e
le = 0x5f
ge = 0x60
abs = 0x8b
neg = 0x8c
ceil = 0x8d
floor = 0x8e
trunc = 0x8f
nearest = 0x90
sqrt = 0x91
add = 0x92
sub = 0x93
mul = 0x94
div = 0x95
min = 0x96
max = 0x97
copysign = 0x98
convert_i32_s = 0xb2
convert_i32_u = 0xb3
convert_i64_s = 0xb4
convert_i64_u = 0xb5
demote_f64 = 0xb6
reinterpret_i32 = 0xbe
load = 0x2a
store = 0x38
const = 0x43
class f64(HexEnum, encoding=0x7c):
eq = 0x61
ne = 0x62
lt = 0x63
gt = 0x64
le = 0x65
ge = 0x66
abs = 0x99
neg = 0x9a
ceil = 0x9b
floor = 0x9c
trunc = 0x9d
nearest = 0x9e
sqrt = 0x9f
add = 0xa0
sub = 0xa1
mul = 0xa2
div = 0xa3
min = 0xa4
max = 0xa5
copysign = 0xa6
convert_i32_s = 0xb7
convert_i32_u = 0xb8
convert_i64_s = 0xb9
convert_i64_u = 0xba
promote_f32 = 0xbb
reinterpret_i64 = 0xbf
load = 0x2b
store = 0x39
const = 0x44
class local(HexEnum):
get = 0x20
set = 0x21
tee = 0x22
class global_(HexEnum):
get = 0x23
set = 0x24
global_.__name__ = 'global'
# Special void type for block returns
void = 0x40
# ------------------------------------------------------------
def encode_function_type(parameters, results):
'''
parameters is a vector of value types
results is a vector value types
'''
enc_parms = bytes(parameters)
enc_results = bytes(results)
return b'\x60' + encode_vector(enc_parms) + encode_vector(enc_results)
def encode_limits(min, max=None):
if max is None:
return b'\x00' + encode_unsigned(min)
else:
return b'\x01' + encode_unsigned(min) + encode_unsigned(max)
def encode_table_type(elemtype, min, max=None):
return b'\x70' + encode_limits(min, max)
def encode_global_type(value_type, mut=True):
return bytes([value_type, mut])
# ----------------------------------------------------------------------
# Instruction builders
#
# Wasm instructions are grouped into different namespaces. For example:
#
# i32.add()
# local.get()
# memory.size()
# ...
#
# The classes that follow implement the namespace for different instruction
# categories.
# Builder for the local.* namespace
class SubBuilder:
def __init__(self, builder):
self._builder = builder
def _append(self, instr):
self._builder._code.append(instr)
class LocalBuilder(SubBuilder):
def get(self, localidx):
self._append([local.get, *encode_unsigned(localidx)])
def set(self, localidx):
self._append([local.set, *encode_unsigned(localidx)])
def tee(self, localidx):
self._append([local.tee, *encode_unsigned(localidx)])
class GlobalBuilder(SubBuilder):
def get(self, glob):
if isinstance(glob, int):
globidx = glob
else:
globidx = glob.idx
self._append([global_.get, *encode_unsigned(globidx)])
def set(self, glob):
if isinstance(glob, int):
globidx = glob
else:
globidx = glob.idx
self._append([global_.set, *encode_unsigned(globidx)])
class MemoryBuilder(SubBuilder):
def size(self):
self._append([0x3f, 0x00])
def grow(self):
self._append([0x40, 0x00])
class OpBuilder(SubBuilder):
_optable = None # To be supplied by subclasses
# Memory ops
def load(self, align, offset):
self._append([self._optable.load, *encode_unsigned(align), *encode_unsigned(offset)])
def load8_s(self, align, offset):
self._append([self._optable.load8_s, *encode_unsigned(align), *encode_unsigned(offset)])
def load8_u(self, align, offset):
self._append([self._optable.load8_u, *encode_unsigned(align), *encode_unsigned(offset)])
def load16_s(self, align, offset):
self._append([self._optable.load16_s, *encode_unsigned(align), *encode_unsigned(offset)])
def load16_u(self, align, offset):
self._append([self._optable.load16_u, *encode_unsigned(align), *encode_unsigned(offset)])
def load32_s(self, align, offset):
self._append([self._optable.load32_s, *encode_unsigned(align), *encode_unsigned(offset)])
def load32_u(self, align, offset):
self._append([self._optable.load32_u, *encode_unsigned(align), *encode_unsigned(offset)])
def store(self, align, offset):
self._append([self._optable.store, *encode_unsigned(align), *encode_unsigned(offset)])
def store8(self, align, offset):
self._append([self._optable.store8, *encode_unsigned(align), *encode_unsigned(offset)])
def store16(self, align, offset):
self._append([self._optable.store16, *encode_unsigned(align), *encode_unsigned(offset)])
def store32(self, align, offset):
self._append([self._optable.store32, *encode_unsigned(align), *encode_unsigned(offset)])
def __getattr__(self, key):
def call():
self._append([getattr(self._optable, key)])
return call
class I32OpBuilder(OpBuilder):
_optable = i32
def const(self, value):
self._append([self._optable.const, *encode_signed(value)])
class I64OpBuilder(OpBuilder):
_optable = i64
def const(self, value):
self._append([self._optable.const, *encode_signed(value)])
class F32OpBuilder(OpBuilder):
_optable = f32
def const(self, value):
self._append([self._optable.const, *encode_float32(value)])
class F64OpBuilder(OpBuilder):
_optable = f64
def const(self, value):
self._append([self._optable.const, *encode_float64(value)])
def _flatten(instr):
for x in instr:
if isinstance(x, list):
yield from _flatten(x)
else:
yield x
# High-level class that allows instructions to be easily encoded.
class InstructionBuilder:
def __init__(self):
self._code = [ ]
self.local = LocalBuilder(self)
self.global_ = GlobalBuilder(self)
self.i32 = I32OpBuilder(self)
self.i64 = I64OpBuilder(self)
self.f32 = F32OpBuilder(self)
self.f64 = F64OpBuilder(self)
# Control-flow stack.
self._control = [ None ]
def __iter__(self):
return iter(self._code)
# Resolve a human-readable label into control-stack index
def _resolve_label(self, label):
if isinstance(label, int):
return label
index = self._control.index(label)
return len(label) - 1 - index
# Control flow instructions
def unreachable(self):
self._code.append([0x01])
def nop(self):
self._code.append([0x01])
def block_start(self, result_type, label=None):
self._code.append([0x02, result_type])
self._control.append(label)
return len(self._control)
def block_end(self):
self._code.append([0x0b])
self._control.pop()
def loop_start(self, result_type, label=None):
self._code.append([0x03, result_type])
self._control.append(label)
return len(self._control)
def if_start(self, result_type, label=None):
self._code.append([0x04, result_type])
self._control.append(label)
def else_start(self):
self._code.append([0x05])
def br(self, label):
labelidx = self._resolve_label(label)
self._code.append([0x0c, *encode_unsigned(labelidx)])
def br_if(self, label):
labelidx = self._resolve_label(label)
self._code.append([0x0d, *encode_unsigned(labelidx)])
def br_table(self, labels, label):
enc_labels = [encode_unsigned(self._resolve_label(idx)) for idx in labels]
self._code.append([0x0e, *encode_vector(enc_labels), *encode_unsigned(self._resolve_label(label))])
def return_(self):
self._code.append([0x0f])
def call(self, func):
if isinstance(func, (ImportFunction,Function)):
self._code.append([0x10, *encode_unsigned(func._idx)])
else:
self._code.append([0x10, *encode_unsigned(func)])
def call_indirect(self, typesig):
if isinstance(typesig, Type):
typeidx = typesig.idx
else:
typeidx = typesig
self._code.append([0x11, *encode_unsigned(typeidx), 0x00])
def drop(self):
self._code.append([0x1a])
def select(self):
self._code.append([0x1b])
class Type:
def __init__(self, parms, results, idx):
self.parms = parms
self.results = results
self.idx = idx
def __repr__(self):
return f'{self.parms!r} -> {self.results!r}'
class ImportFunction:
def __init__(self, name, typesig, idx):
self._name = name
self._typesig = typesig
self._idx = idx
def __repr__(self):
return f'ImportFunction({self._name}, {self._typesig}, {self._idx})'
class Function(InstructionBuilder):
def __init__(self, name, typesig, idx, export=True):
super().__init__()
self._name = name
self._typesig = typesig
self._locals = list(typesig.parms)
self._export = export
self._idx = idx
def __repr__(self):
return f'Function({self._name}, {self._typesig}, {self._idx})'
# Allocate a new local variable of a given type
def alloc(self, valuetype):
self._locals.append(valuetype)
return len(self.locals) - 1
class ImportGlobal:
def __init__(self, name, valtype, idx):
self.name = name
self.valtype = valtype
self.idx = idx
def __repr__(self):
return f'ImportGlobal({self.name}, {self.valtype}, {self.idx})'
class Global:
def __init__(self, name, valtype, initializer, idx):
self.name = name
self.valtype = valtype
self.initializer = initializer
self.idx = idx
def __repr__(self):
return f'Global({self.name}, {self.valtype}, {self.initializer}, {self.idx})'
class Module:
def __init__(self):
# Vector of function type signatures. Signatures are reused
# if more than one function has the same signature.
self.type_section = []
# Vector of imported entities. These can be functions, globals,
# tables, and memories
self.import_section = []
# There are 4 basic entities within a Wasm file. Functions,
# globals, memories, and tables. Each kind of entity is
# stored in a separate list and is indexed by an integer
# index starting at 0. Imported entities must always
# go before entities defined in the Wasm module itself.
self.funcidx = 0
self.globalidx = 0
self.memoryidx = 0
self.tableidx = 0
self.function_section = [] # Vector of typeidx
self.global_section = [] # Vector of globals
self.table_section = [] # Vector of tables
self.memory_section = [] # Vector of memories
# Exported entities. A module may export functions, globals,
# tables, and memories
self.export_section = [] # Vector of exports
# Optional start function. A function that executes upon loading
self.start_section = None # Optional start function
# Initialization of table elements
self.element_section = []
# Code section for function bodies.
self.code_section = []
# Data section contains data segments
self.data_section = []
# List of function objects (to help with encoding)
self.functions = []
# Output for JS/Html
self.js_exports = "";
self.html_exports = "";
self.js_imports = defaultdict(dict)
def add_type(self, parms, results):
enc = encode_function_type(parms, results)
if enc in self.type_section:
return Type(parms, results, self.type_section.index(enc))
else:
self.type_section.append(enc)
return Type(parms, results, len(self.type_section) - 1)
def import_function(self, module, name, parms, results):
if len(self.function_section) > 0:
raise RuntimeError('function imports must go before first function definition')
typesig = self.add_type(parms, results)
code = encode_name(module) + encode_name(name) + b'\x00' + encode_unsigned(typesig.idx)
self.import_section.append(code)
self.js_imports[module][name] = f"function: {typesig}"
self.funcidx += 1
return ImportFunction(f'{module}.{name}', typesig, self.funcidx - 1)
def import_table(self, module, name, elemtype, min, max=None):
code = encode_name(module) + encode_name(name) + b'\x01' + encode_table_type(elemtype, min, max)
self.import_section.append(code)
self.js_imports[module][name] = "table:"
self.tableidx += 1
return self.tableidx - 1
def import_memtype(self, module, name, min, max=None):
code = encode_name(module) + encode_name(name) + b'\x02' + encode_limits(min, max)
self.import_section.append(code)
self.js_imports[module][name] = "memory:"
self.memoryidx += 1
return self.memoryidx - 1
def import_global(self, module, name, value_type):
if len(self.global_section) > 0:
raise RuntimeError('global imports must go before first global definition')
code = encode_name(module) + encode_name(name) + b'\x03' + encode_global_type(value_type, False)
self.import_section.append(code)
self.js_imports[module][name] = f"global: {value_type}"
self.globalidx += 1
return ImportGlobal(f'{module}.{name}', value_type, self.globalidx - 1)
def add_function(self, name, parms, results, export=True):
typesig = self.add_type(parms, results)
func = Function(name, typesig, self.funcidx, export)
self.funcidx += 1
self.functions.append(func)
self.function_section.append(encode_unsigned(typesig.idx))
self.html_exports += f'<p><tt>{name}({", ".join(str(p) for p in parms)}) -> {results[0]!s}</tt></p>\n'
return func
def add_table(self, elemtype, min, max=None):
self.table_section.append(encode_table_type(elemtype, min, max))
self.tableidx += 1
return self.tableidx - 1
def add_memory(self, min, max=None):
self.memory_section.append(encode_limits(min, max))
self.memoryidx += 1
return self.memoryidx - 1
def add_global(self, name, value_type, initializer, mutable=True, export=True):
code = encode_global_type(value_type, mutuable)
expr = InstructionBuilder()
getattr(expr, str(valtype)).const(initializer)
expr.finalize()
code += expr._code
self.global_section.append(code)
if export:
self.export_global(name, self.globalidx)
self.globalidx += 1
return Global(name, value_type, initializer, self.globalidx - 1)
def export_function(self, name, funcidx):
code = encode_name(name) + b'\x00' + encode_unsigned(funcidx)
self.export_section.append(code)
self.js_exports += f'window.{name} = results.instance.exports.{name};\n'
def export_table(self, name, tableidx):
code = encode_name(name) + b'\x01' + encode_unsigned(tableidx)
self.export_section.append(code)
def export_memory(self, name, memidx):
code = encode_name(name) + b'\x02' + encode_unsigned(memidx)
self.export_section.append(code)
def export_global(self, name, globalidx):
code = encode_name(name) + b'\x03' + encode_unsigned(globalidx)
self.export_section.append(code)
def start_function(self, funcidx):
self.start = encode_unsigned(funcidx)
def add_element(self, tableidx, expr, funcidxs):
code = encode_unsigned(tableidx) + expr.code
code += encode_vector([encode_unsigned(i) for i in funcidxs])
self.element_section.append(code)
def add_function_code(self, locals, expr):
# Locals is a list of valtypes [i32, i32, etc...]
# expr is an expression representing the actual code (InstructionBuilder)
locs = [ encode_unsigned(1) + bytes([loc]) for loc in locals ]
locs_code = encode_vector(locs)
func_code = locs_code + bytes(_flatten(expr))
code = encode_unsigned(len(func_code)) + func_code
self.code_section.append(code)
def add_data(self, memidx, expr, data):
# data is bytes
code = encode_unsigned(memidx) + expr.code + encode_vector([data[i:i+1] for i in range(len(data))])
self.data_section.append(code)
def _encode_section_vector(self, sectionid, contents):
if not contents:
return b''
contents_code = encode_vector(contents)
code = bytes([sectionid]) + encode_unsigned(len(contents_code)) + contents_code
return code
def encode(self):
for func in self.functions:
self.add_function_code(func._locals, func._code)
if func._export:
self.export_function(func._name, func._idx)
# Encode the whole module
code = b'\x00\x61\x73\x6d\x01\x00\x00\x00'
code += self._encode_section_vector(1, self.type_section)
code += self._encode_section_vector(2, self.import_section)
code += self._encode_section_vector(3, self.function_section)
code += self._encode_section_vector(4, self.table_section)
code += self._encode_section_vector(5, self.memory_section)
code += self._encode_section_vector(6, self.global_section)
code += self._encode_section_vector(7, self.export_section)
if self.start_section:
code += encode_unsigned(8) + self.start_section
code += self._encode_section_vector(9, self.element_section)
code += self._encode_section_vector(10, self.code_section)
code += self._encode_section_vector(11, self.data_section)
return code
def write_wasm(self, modname):
with open(f'{modname}.wasm', 'wb') as f:
f.write(self.encode())
def write_html(self, modname):
with open(f'{modname}.html', 'wt') as f:
f.write(js_template.format(
module=modname,
imports=json.dumps(self.js_imports, indent=4),
exports=self.js_exports,
exports_html=self.html_exports,
)
)
js_template = '''
<html>
<body>
<script>
var imports = {imports};
fetch("{module}.wasm").then(response =>
response.arrayBuffer()
).then(bytes =>
WebAssembly.instantiate(bytes, imports)
).then(results => {{
{exports}
}});
</script>
<h3>module {module}</h3>
<p>
The following exports are made. Access from the JS console.
</p>
{exports_html}
</body>
</html>
'''
def test1():
mod = Module()
# An external function import. Note: All imports MUST go first.
# Indexing affects function indexing for functions defined in the module.
# Import some functions from JS
# math_sin = mod.import_function('util', 'sin', [f64], [f64])
# math_cos = mod.import_function('util', 'cos', [f64], [f64])
# Import a function from another module entirely
# fact = mod.import_function('recurse', 'fact', [i32], [i32])
# Import a global variable (from JS?)
# FOO = mod.import_global('util', 'FOO', f64)
# A more complicated function
dsquared_func = mod.add_function('dsquared', [f64, f64], [f64])
dsquared_func.local.get(0)
dsquared_func.local.get(0)
dsquared_func.f64.mul()
dsquared_func.local.get(1)
dsquared_func.local.get(1)
dsquared_func.f64.mul()
dsquared_func.f64.add()
dsquared_func.block_end()
# A function calling another function
distance = mod.add_function('distance', [f64, f64], [f64])
distance.local.get(0)
distance.local.get(1)
distance.call(dsquared_func)
distance.f64.sqrt()
distance.block_end()
# A function calling out to JS
# ext = mod.add_function('ext', [f64, f64], [f64])
# ext.local.get(0)
# ext.call(math_sin)
# ext.local.get(1)
# ext.call(math_cos)
# ext.f64.add()
# ext.block_end()
# A function calling across modules
# tenf = mod.add_function('tenfact', [i32], [i32])
# tenf.local.get(0)
# tenf.call(fact)
# tenf.i32.const(10)
# tenf.i32.mul()
# tenf.block_end()
# A function accessing an imported global variable
# gf = mod.add_function('gf', [f64], [f64])
# gf.global_.get(FOO)
# gf.local.get(0)
# gf.f64.mul()
# gf.block_end()
# Memory
mod.add_memory(1)
mod.export_memory('memory', 0)
# Function that returns a byte value
getval = mod.add_function('getval', [i32], [i32])
getval.local.get(0)
getval.i32.load8_u(0, 0)
getval.block_end()
# Function that sets a byte value
setval = mod.add_function('setval', [i32,i32], [i32])
setval.local.get(0) # Memory address
setval.local.get(1) # value
setval.i32.store8(0,0)
setval.i32.const(1)
setval.block_end()
return mod
def test2():
mod = Module()
fact = mod.add_function('fact', [i32], [i32])
fact.local.get(0)
fact.i32.const(1)
fact.i32.lt_s()
fact.if_start(i32)
fact.i32.const(1)
fact.else_start()
fact.local.get(0)
fact.local.get(0)
fact.i32.const(1)
fact.i32.sub()
fact.call(fact)
fact.i32.mul()
fact.block_end()
fact.block_end()
return mod
if __name__ == '__main__':
mod = test1()
mod.write_wasm('test')
mod.write_html('test')