Skip to content
Snippets Groups Projects
Commit 949c3b55 authored by Dylan Baker's avatar Dylan Baker Committed by Marge Bot
Browse files

util/glsl2spirv: add type annotations


Which are all clean

Reviewed-by: default avatarLuis Felipe Strano Moraes <luis.strano@gmail.com>
Part-of: <mesa/mesa!19449>
parent c01cd8ca
No related branches found
No related tags found
No related merge requests found
......@@ -21,15 +21,29 @@
# Converts GLSL shader to SPIR-V library
from __future__ import annotations
import argparse
import subprocess
import os
import typing as T
if T.TYPE_CHECKING:
class Arguments(T.Protocol):
input: str
output: str
create_entry: T.Optional[str]
glsl_ver: T.Optional[str]
Olib: bool
extra: T.Optional[str]
vn: str
stage: str
class ShaderCompileError(RuntimeError):
def __init__(self, *args):
super(ShaderCompileError, self).__init__(*args)
def get_args():
def get_args() -> Arguments:
parser = argparse.ArgumentParser()
parser.add_argument('input', help="Name of input file.")
parser.add_argument('output', help="Name of output file.")
......@@ -64,7 +78,7 @@ def get_args():
return args
def create_include_guard(lines, filename):
def create_include_guard(lines: T.List[str], filename: str) -> T.List[str]:
filename = filename.replace('.', '_')
upper_name = filename.upper()
......@@ -81,7 +95,7 @@ def create_include_guard(lines, filename):
return guard_head + lines + guard_tail
def convert_to_static_variable(lines, varname):
def convert_to_static_variable(lines: T.List[str], varname: str) -> T.List[str]:
for idx, l in enumerate(lines):
if l.find(varname) != -1:
lines[idx] = "static " + lines[idx]
......@@ -89,7 +103,7 @@ def convert_to_static_variable(lines, varname):
raise RuntimeError(f'Did not find {varname}, this is unexpected')
def override_version(lines, glsl_version):
def override_version(lines: T.List[str], glsl_version: str) -> T.List[str]:
for idx, l in enumerate(lines):
if l.find('#version ') != -1:
lines[idx] = "#version {}\n".format(glsl_version)
......@@ -97,7 +111,7 @@ def override_version(lines, glsl_version):
raise RuntimeError('Did not find #version directive, this is unexpected')
def postprocess_file(args):
def postprocess_file(args: Arguments) -> None:
with open(args.output, "r") as r:
lines = r.readlines()
......@@ -111,7 +125,7 @@ def postprocess_file(args):
w.writelines(lines)
def preprocess_file(args, origin_file):
def preprocess_file(args: Arguments, origin_file: T.TextIO) -> str:
with open(origin_file.name + ".copy", "w") as copy_file:
lines = origin_file.readlines()
......@@ -126,7 +140,7 @@ def preprocess_file(args, origin_file):
return copy_file.name
def process_file(args):
def process_file(args: Arguments) -> None:
with open(args.input, "r") as infile:
copy_file = preprocess_file(args, infile)
......@@ -169,7 +183,7 @@ def process_file(args):
os.remove(copy_file)
def main():
def main() -> None:
args = get_args()
process_file(args)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment