Commit a435811e authored by Lyude Paul

wip: Add clause header encoding

parent a4dac95e
......@@ -40,7 +40,9 @@ except Exception:
sys.exit(-1)
class ParsingException(Exception):
    """Error raised when the assembler input cannot be parsed.

    The message is stored on ``msg`` in addition to being passed to
    ``Exception`` so that callers can re-wrap it into a more specific
    message.
    """

    def __init__(self, msg):
        self.msg = msg
        super().__init__(msg)
class OpParserBase:
""" Base for op parser classes """
......@@ -66,7 +68,7 @@ class OpParserBase:
def __init__(self, expected, got):
super().__init__('Expected %d sources, got %d' % (expected, got))
def parse_op(self, reg_file, srcs):
def parse_op(self, clause, reg_file, srcs):
if self.src_cnt and len(srcs) != self.src_cnt:
raise self.SrcCountException(self.src_cnt, len(srcs))
......@@ -93,6 +95,7 @@ class SrcOpParserBase(OpParserBase):
opcode >>= (src_cnt - 1) * 3
super().__init__(name, opcode, src_cnt)
class fma:
NAME = 'FMA'
DST_MNEMONIC = 'T0'
......@@ -325,9 +328,11 @@ class add:
src_cnt = 2
class ATestOpParser(SrcOpParser):
    """Op parser that marks its clause as an alpha-test clause."""

    def parse_op(self, clause, reg_file, srcs):
        """Tag ``clause`` as ALPHA_TEST, set the constant port and parse.

        The constant port of ``reg_file`` is set to the 8-bit value 5
        before the sources are handed to the parent parser.
        """
        clause.clause_type = Clause.ClauseType.ALPHA_TEST
        reg_file.const_port = Bits(length=8, uint=5)
        return super().parse_op(clause, reg_file, srcs)
class LoadAttrOpParser(OpParser):
    """Op parser that adds no behaviour on top of ``OpParser``."""
......@@ -358,18 +363,20 @@ class add:
return '<BlendDescriptor at 0x%x; idx=%d>' % (
id(self), self.idx)
def parse_op(self, clause, reg_file, srcs):
    """Parse a blend op whose first source is a ``('location', N)`` token.

    The location token is removed from ``srcs``, port 1 of ``reg_file`` is
    disabled, the clause is tagged as a BLEND clause and the constant port
    is set to the blend descriptor named by the token. The remaining
    sources are parsed by the parent class.

    Raises:
        ParsingException: if the location token is missing or malformed,
            or if its index is not an integer.
    """
    if not srcs:
        raise ParsingException("Missing location source")

    loc_token = srcs.pop(0)
    if not isinstance(loc_token, tuple) or loc_token[0] != 'location':
        # Wrap in a 1-tuple: loc_token may itself be a tuple, which
        # %-formatting would otherwise unpack as the argument list.
        raise ParsingException(
            "Invalid src '%s' (must be loc_token)" % (loc_token,))

    reg_file.disable_port(1)
    clause.clause_type = Clause.ClauseType.BLEND

    try:
        reg_file.const_port = self.BlendDescriptor(int(loc_token[1]))
    except ValueError as e:
        raise ParsingException(
            "Invalid blend descriptor '%s'" % loc_token[1]) from e

    return super().parse_op(clause, reg_file, srcs)
def __init__(self):
pass
......@@ -587,7 +594,8 @@ class Parser:
s = self.INLINE_COMMENT.sub('', s).rstrip()
if s != '':
self.last_line = i, s
yield (i, s)
yield (i+1, s)
class UniformToken:
def __init__(self, idx, high32):
......@@ -602,6 +610,7 @@ class UniformToken:
return '<UniformToken %d (%s)>' % (
self.canonical_idx, 'High32' if self.high32 else 'Low32')
class ImmediateToken:
class ReadType(Enum):
FULL32 = None
......@@ -629,6 +638,7 @@ class ImmediateToken:
return '<ImmediateToken at 0x%x 0x%x; type=%s>' % (
id(self), self.value, self.read_type.name)
class OpResult(Enum):
PREV_FMA = 'T0'
PREV_ADD = 'T1'
......@@ -644,6 +654,7 @@ class OpResult(Enum):
return bitstring.pack('uint:3', field_val)
class RegisterToken:
"""
A token for a register that hasn't been assigned to a RegisterFile yet
......@@ -657,6 +668,7 @@ class RegisterToken:
def __eq__(self, other):
return self.idx == other.idx
class RegisterFile:
class Register:
def __init__(self, idx, write_stage=None):
......@@ -681,6 +693,7 @@ class RegisterFile:
id(self)
)
class ControlField(IntEnum):
WRITE_FMA_P2 = 1
WRITE_FMA_P2_READ_P3 = 3
......@@ -731,8 +744,7 @@ class RegisterFile:
return
elif self.__const_port is not None:
if isinstance(self.__const_port,
Clause.Instruction.PendingImmediateSlot):
if isinstance(self.__const_port, Instruction.PendingImmediateSlot):
if (not isinstance(value, ImmediateSlot) and
value is not ImmediateZeroSlot):
raise self.ConstPortInUse(self.__const_port, value)
......@@ -900,6 +912,7 @@ class RegisterFile:
port_fields[2],
const_field)
class ConstantSrc:
def __init__(self, high32):
self.high32 = high32
......@@ -911,6 +924,7 @@ class ConstantSrc:
return '<ConstantSrc (%s) at 0x%x>' % (
'high' if self.high32 else 'low', id(self))
class ImmediateZeroSlot:
@classmethod
def encode_const_field(cls):
......@@ -921,6 +935,7 @@ class ImmediateZeroSlot:
assert token.read_type != ImmediateToken.ReadType.FULL32
return ConstantSrc(token.read_type == ImmediateToken.ReadType.HIGH64)
class ImmediateSlot:
""" A slot that can hold one 64 bit const, or two 32 bit consts. """
......@@ -952,6 +967,7 @@ class ImmediateSlot:
return "<ImmediateSlot #%d at 0x%x; contents=%s>" % (
self.idx, id(self), self.contents)
class Uniform:
def __init__(self, idx):
self.idx = idx
......@@ -964,7 +980,7 @@ class Uniform:
return ConstantSrc(token.high32)
def encode_const_field(self):
    """Return the 8-bit constant field: top bit set, ``idx`` in the rest."""
    return Bits(length=8, uint=0x80 | self.idx)
def canonical_idx_str(self):
canonical_idx = self.idx * 2
......@@ -973,127 +989,314 @@ class Uniform:
def __repr__(self):
return "<Uniform %s at 0x%x>" % (self.canonical_idx_str(), id(self))
class Clause:
class DataRegister:
def __init__(self, idx):
self.idx = idx
def __repr__(self):
return '<Data register %d at %s>' % (self.idx, hex(id(self)))
class Instruction:
    """An instruction made of an ``fma`` stage and an ``add`` stage.

    Both stages share one ``RegisterFile``; its constant port holds either
    a ``PendingImmediateSlot``, a resolved immediate slot or a ``Uniform``.
    """

    class ImmediateCountError(ParsingException):
        """Raised when there is no room left for an immediate token."""

        def __init__(self, token):
            if token.read_type == ImmediateToken.ReadType.FULL32:
                mod_str = ""
            else:
                # The modifier comes from the token, not from the exception.
                mod_str = ".%s" % token.read_type.value

            super().__init__("No space left for immediate 0x%x%s" % (
                token.value, mod_str))

    class PendingImmediateSlot:
        """Immediate tokens collected before a real slot is assigned."""

        def __init__(self):
            self.__contents = []

        def __iter__(self):
            return iter(self.__contents)

        def __len__(self):
            return len(self.__contents)

        def __repr__(self):
            return '<PendingImmediateSlot at 0x%x; contents=%s>' % (
                id(self), self.__contents)

        @property
        def contents(self):
            """The list of tokens added so far."""
            return self.__contents

        @property
        def bitlen(self):
            """Sum of the ``bitlen`` of every token held."""
            return sum(t.bitlen for t in self.contents)

        def add_immediate(self, token):
            """Add ``token`` unless it is already held.

            Raises:
                ParsingException: if adding the token would push the total
                    past 64 bits.
            """
            if token in self.__contents:
                return

            if self.bitlen + token.bitlen > 64:
                raise ParsingException(
                    'Too many constants for one instruction cycle')
            self.__contents.append(token)

    def __init__(self, first):
        self.first = first
        self.reg_file = RegisterFile()
        self.fma = None
        self.add = None
        self.writes = {fma: None, add: None}

    def __repr__(self):
        return '<Instruction at %s; fma=%s, add=%s, first=%s, writes=%s, reg_file=%s>' % (
            hex(id(self)), self.fma, self.add, self.first,
            self.writes, self.reg_file
        )

    def add_pending_immediate(self, token):
        """
        Add a pending immediate token to this instruction and check that we
        will still be able to encode this instruction.
        """
        pending_slot = self.reg_file.const_port
        if not self.has_pending_immediates():
            pending_slot = self.PendingImmediateSlot()
            self.reg_file.const_port = pending_slot

        pending_slot.add_immediate(token)

    def resolve_immediates(self, clause, slot=None):
        """Swap the pending immediates for a real slot.

        When ``slot`` is not given, one is requested from ``clause`` for the
        pending contents. Every ``ImmediateToken`` in the stages' source
        lists is then replaced by the source the slot hands back for it.
        """
        if not slot:
            slot = clause._get_immediate_slot(self.reg_file.const_port.contents)

        self.reg_file.const_port = slot

        # Replace ImmediateTokens in our src list with proper immediate
        # sources
        for stage in self.stages:
            for idx, src in enumerate(stage.srcs):
                if isinstance(src, ImmediateToken):
                    stage.srcs[idx] = slot.get_src(src)

    def has_pending_immediates(self):
        """Return True while the constant port is a PendingImmediateSlot."""
        return isinstance(self.reg_file.const_port,
                          self.PendingImmediateSlot)

    def add_uniform(self, token):
        """Return the constant source for ``token``.

        A ``Uniform`` for ``token.idx`` is installed on the constant port
        first unless the port already holds a ``Uniform``.
        """
        const_port = self.reg_file.const_port
        if not isinstance(const_port, Uniform):
            const_port = Uniform(token.idx)
            self.reg_file.const_port = const_port

        return const_port.get_src(token)

    @property
    def pending_stage(self):
        """The first unfilled stage class (``fma``, then ``add``), or None."""
        if self.fma is None:
            return fma
        elif self.add is None:
            return add

    @property
    def stages(self):
        """The ``(fma, add)`` stage pair."""
        return (self.fma, self.add)
class Clause:
class ClauseType(IntEnum):
    """Numeric clause type values used in the clause header."""

    # XXX This isn't going to be used just yet
    NONE = 0
    SSBO_LOAD = 5
    SSBO_STORE = 6
    BLEND = 9
    ALPHA_TEST = 13
MAX_CONSTS_ALLOWED = 5 # FIXME: maybe make this 6 later?
def __init__(self, idx, header):
self.instructions = []
self.idx = idx
self.header = header
self.__data_reg = None
def _header_flag(func):
old_func = func
def func(self, value):
if value is None:
raise ParsingException("Missing arguments for flag")
old_func(self, value)
return func
def _bool_header_flag(func):
old_func = func
def func(self, unused):
if unused is not None:
raise ParsingException("This flag takes no arguments")
old_func(self)
return func
class ClauseIdError(ParsingException):
    """Raised for a clause id that is malformed or out of range."""

    def __init__(self, id_str, msg=None):
        # The optional detail is appended after a colon.
        detail = ': %s' % msg if msg else ''
        super().__init__("Invalid clause id '%s'%s" % (id_str, detail))
@_header_flag
def _parse_header_id(self, id_):
    """Parse the ``id`` flag (``<n>u``) into ``scoreboard_entry``.

    Raises:
        ClauseIdError: if the value does not end in ``u``, the number is
            not an integer, or it lies outside 0-5.
    """
    # endswith() also rejects an empty string instead of raising IndexError.
    if not id_.endswith('u'):
        raise self.ClauseIdError(id_)

    try:
        entry = int(id_[:-1])
    except ValueError as e:
        raise self.ClauseIdError(id_) from e

    if entry < 0 or entry > 5:
        raise self.ClauseIdError(entry, "ID must be between 0 and 5")

    self.scoreboard_entry = entry
@_header_flag
def _parse_header_next_wait(self, next_wait):
    """Parse a comma-separated id list and append it to ``scoreboard_deps``.

    Raises:
        ClauseIdError: if an entry is not an integer or lies outside 0-7.
    """
    for raw in next_wait.split(','):
        try:
            dep = int(raw.strip())
        except ValueError as e:
            raise self.ClauseIdError(raw) from e

        if not 0 <= dep <= 7:
            raise self.ClauseIdError(dep, "ID must be between 0 and 7")

        self.scoreboard_deps.append(dep)
@_bool_header_flag
def _parse_header_data_reg_barrier(self):
    """Set ``data_reg_write_barrier`` to True."""
    self.data_reg_write_barrier = True
@_bool_header_flag
def _parse_header_eos(self):
    """Set ``end_of_shader`` to True."""
    self.end_of_shader = True
@_bool_header_flag
def _parse_header_nbb(self):
    """Clear ``back_to_back``."""
    self.back_to_back = False
def _validate_header_nbb(self):
if not hasattr(self, 'branch_conditional'):
raise ParsingException(
'Neither branch-cond or branch-uncond specified')
@_bool_header_flag
def _parse_header_we(self):
    """Set ``elide_writes`` to True."""
    self.elide_writes = True
class BackToBackError(ParsingException):
    """Raised when a branch flag is used while back-to-back is still set."""

    def __init__(self):
        message = "Flag isn't applicable without nbb"
        super().__init__(message)
@_bool_header_flag
def _parse_header_branch_cond(self):
    """Set ``branch_conditional`` to True."""
    self.branch_conditional = True
def _validate_header_branch_cond(self):
if self.back_to_back:
raise self.BackToBackError()
if not self.branch_conditional:
raise ParsingException("Can't use this flag with branch-uncond")
@_bool_header_flag
def _parse_header_branch_uncond(self):
    """Set ``branch_conditional`` to False."""
    self.branch_conditional = False
def _validate_header_branch_uncond(self):
if self.back_to_back:
raise self.BackToBackError()
if self.branch_conditional:
raise ParsingException("Can't use this flag with branch-cond")
@_bool_header_flag
def _parse_header_unk0(self):
    """Set ``unk0`` to True."""
    self.unk0 = True
@_header_flag
def _parse_header_unk1(self, value):
    """Parse the two-bit ``unk1`` field.

    Raises:
        ParsingException: if the value is not an integer or does not fit
            in two unsigned bits (0-3).
    """
    try:
        value = int(value)
    except ValueError as e:
        raise ParsingException("Must be a value between 0-3") from e

    # An explicit range check also rejects -1..-3, which a bit_length()
    # test lets through even though the field is encoded as unsigned.
    if not 0 <= value <= 3:
        raise ParsingException("Invalid value (known bitlength is 2)")

    self.unk1 = value
@_bool_header_flag
def _parse_header_unk3(self):
    """Set ``unk3`` to True."""
    self.unk3 = True
@_header_flag
def _parse_header_unk4(self, value):
    """Placeholder for the ``unk4`` flag; parsing is not implemented.

    ``value`` is accepted because the ``_header_flag`` wrapper always
    passes the flag's argument through.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "it's 11 bits, the disasm doesn't have a format for this yet")
def _parse_header_flag(self, flag, value):
"""
Try looking up the parser for a given header flag and parse it
"""
try:
func_name = '_parse_header_%s' % flag.replace('-', '_')
getattr(self, func_name)(value)
except AttributeError as e:
raise ParsingException("Unknown flag '%s' in clause header" % flag) from e
except ParsingException as e:
raise ParsingException("Invalid value for flag '%s': %s" % (flag, e.msg)) from e
def _validate_header_flag(self, flag):
"""
Try looking up the validator function for a given header flag. If there
is one, run the validator to ensure that the header flag and it's value
are still valid after all of the other flags have been parsed
"""
func_name = '_validate_header_%s' % flag.replace('-', '_')
if not hasattr(self, func_name):
return
try:
getattr(self, func_name)()
except ParsingException as e:
raise ParsingException("Invalid use of flag '%s': %s" % (flag, e.msg)) from e
def _parse_header(self, header):
flags = set()
for flag, value in header.items():
flags.add(flag)
self._parse_header_flag(flag, value)
for flag in flags:
self._validate_header_flag(flag)
def encode_scoreboard_deps(self):
    """Return an 8-bit mask with one bit set per ``scoreboard_deps`` entry."""
    mask = 0
    for entry in self.scoreboard_deps:
        mask = mask | (1 << entry)

    return Bits(length=8, uint=mask)
def encode_header(self):
    """Pack the clause header fields into a ``Bits`` value.

    Fields are packed in the order listed in the format string. A missing
    scoreboard entry or data register is encoded as 0, the branch bit is
    forced on while ``back_to_back`` is set, and ``unk0`` is always packed
    as 0 for now.
    """
    layout = """
bool=unk4,
uint:4=next_clause_type,
bool=unk3,
uint:4=clause_type,
uint:3=scoreboard_index,
bits:8=scoreboard_deps,
uint:6=data_reg,
bool=data_reg_write_barrier,
bool=branch_cond,
bool=elide_writes,
uint:2=unk1,
bool=no_end_of_shader,
bool=back_to_back,
uint:11=unk0
"""
    if self.back_to_back:
        branch_cond = True
    else:
        branch_cond = self.branch_conditional

    fields = dict(
        unk4=self.unk4,
        next_clause_type=self.next_clause_type,
        unk3=self.unk3,
        clause_type=self.clause_type,
        scoreboard_index=self.scoreboard_entry or 0,
        scoreboard_deps=self.encode_scoreboard_deps(),
        data_reg=self.data_reg or 0,
        data_reg_write_barrier=self.data_reg_write_barrier,
        branch_cond=branch_cond,
        elide_writes=self.elide_writes,
        unk1=self.unk1,
        no_end_of_shader=not self.end_of_shader,
        back_to_back=self.back_to_back,
        unk0=0,  # TODO
    )
    return Bits(bitstring.pack(layout, **fields))
# return bitstring.pack('bool, uint:4=0, bool, uint:4=0
@property
def data_reg(self):
......@@ -1101,10 +1304,29 @@ class Clause:
@data_reg.setter
def data_reg(self, val):
    """Set the clause's data register.

    Setting the same register again is allowed.

    Raises:
        ParsingException: if a different data register was already set.
    """
    if self.__data_reg is not None and val != self.__data_reg:
        raise ParsingException(("Only one data register allowed (already "
                                "had R%d)") % self.__data_reg)

    self.__data_reg = val
@property
def clause_type(self):
    """The clause type currently recorded for this clause."""
    return self.__clause_type

@clause_type.setter
def clause_type(self, value):
    """Record the clause type.

    Once a type other than ``ClauseType.NONE`` is recorded, only that same
    type may be set again.

    Raises:
        ParsingException: if a different non-NONE type is already recorded.
    """
    current = self.__clause_type
    assert isinstance(current, self.ClauseType)

    if current is not self.ClauseType.NONE and current != value:
        # Requested type first, existing type second, as the message reads.
        raise ParsingException(
            ("Would need the clause's instruction type to be %s, but the"
             " instruction type is already %s") % (
                 value.name, current.name)
        )

    self.__clause_type = value
def is_first_instruction(self):
    """Return True while no instruction has been added to the clause."""
    return not len(self.instructions)
......@@ -1152,7 +1374,7 @@ class Clause:
def get_pending_instruction(self):
    """Return the instruction being assembled, creating it on first use.

    A new ``Instruction`` is told whether it is the first one in the
    clause.
    """
    if not self.__pending_inst:
        self.__pending_inst = Instruction(self.is_first_instruction())

    return self.__pending_inst
def add_instruction_stage(self, op, dst, srcs):
......@@ -1183,9 +1405,9 @@ class Clause:
inst.reg_file.add_reg_write(write.idx, prev_stage)
self.__pending_writes = {fma: None, add: None}
inst.fma = parser.parse_op(inst.reg_file, srcs)
inst.fma = parser.parse_op(self, inst.reg_file, srcs)
else:
inst.add = parser.parse_op(inst.reg_file, srcs)
inst.add = parser.parse_op(self, inst.reg_file, srcs)
# We attempt to match what the compiler does here. Basically: If
# this instruction uses enough immediates to fill an entire
......@@ -1207,12 +1429,35 @@ class Clause:
if dst:
self.__pending_writes[stage] = dst
def __repr__(self):
return '<Clause #%d @ %d>' % (self.idx, id(self))
def dump(self):
print('Clause #%d' % self.idx)
print('\tHeader: %s' % self.header)
print('\tHeader:')
print(f'\t\tEncoded: {self.encode_header().uint:#x} {self.encode_header().bin}')
print('\t\tType: %s' % self.clause_type.name)
print('\t\tNext type: %s' % self.next_clause_type.name)
if self.scoreboard_entry is not None:
print(f'\t\tScoreboard entry: {1 << self.scoreboard_entry:#08b} ({self.scoreboard_entry:d})')
if hasattr(self, 'id'):
print('\t\tID: %d' % self.id)
print('\t\tScoreboard dependencies: %s%s' % (
self.encode_scoreboard_deps().bin,
' (%s)' % self.scoreboard_deps if self.scoreboard_deps else ''
))
if hasattr(self, 'next_wait'):
print('\t\tScoreboard deps: %s' % self.next_wait)
print('\t\tData register write barrier: %s' % self.data_reg_write_barrier)
print('\t\tEnd of shader: %s' % self.end_of_shader)
print('\t\tBack to back: %s' % self.back_to_back)
if not self.back_to_back:
print('\t\tBranch conditional: %s' % self.branch_conditional)
print('\t\tElide writes: %s' % self.elide_writes)
print('\t\tunk0: %s' % self.unk0)
print('\t\tunk1: %s' % self.unk1)
print('\t\tunk3: %s' % self.unk3)
print('\t\tunk4: %s' % self.unk4)
print('\tInstructions:')
for i in self.instructions:
print('\t\t%s' % i)
......@@ -1281,6 +1526,34 @@ class Clause:
inst.resolve_immediates(self, slot)
self.immediate_slots.append(slot)
def __repr__(self):
return '<Clause #%d @ %d>' % (self.idx, id(self))
def __init__(self, idx, header):
    """Create clause ``idx`` with default header state, then parse ``header``.

    Every header attribute is given its default before ``_parse_header``
    runs, so flags in ``header`` override these values.
    """
    self.idx = idx

    # Instruction bookkeeping
    self.instructions = []
    self.immediate_slots = []
    self.__pending_writes = {fma: None, add: None}
    self.__pending_inst = None

    # Header defaults
    self.back_to_back = True
    self.branch_conditional = False
    self.end_of_shader = False
    self.elide_writes = False
    self.data_reg_write_barrier = False
    self.__data_reg = None
    self.scoreboard_deps = []
    self.scoreboard_entry = None
    self.__clause_type = self.ClauseType.NONE
    self.next_clause_type = self.ClauseType.NONE

    # Fields whose meaning is not known yet
    self.unk0 = False
    self.unk1 = False
    self.unk3 = False
    self.unk4 = False

    self._parse_header(header)
CLAUSE_START = re.compile(r'^clause_([0-9]+):$')
def scan_clause_start(string):
......@@ -1302,6 +1575,8 @@ def scan_clause_header_tokens(string):
ret = dict()
for m in matches:
key, val = m.groups()
if key in ret:
raise ParsingException("%s flag specified twice" % key)
ret[key] = val
return ret
......@@ -1412,7 +1687,7 @@ def scan_op_tokens(string):
match = DATA_REGISTER.match(data_reg)
if not match:
raise ParsingException("Invalid data register '%s'" % data_reg)
data_reg = Clause.DataRegister(int(match.group(1)))
data_reg = int(match.group(1))
else:
data_reg = None
......@@ -1440,8 +1715,10 @@ if __name__ == '__main__':
p = Parser(args.file, args.verbose)
pp = pprint.PrettyPrinter(indent=4)
itr = iter(p.lines)
try:
unexpected_eof = False
clauses = []
try:
for i, l in itr: