import os

# Select the headless rendering backend; switch to "osmesa" or "egl" before loading pyrender.
os.environ["PYOPENGL_PLATFORM"] = "egl"

import pickle

import gradio as gr
import matplotlib as mpl
import matplotlib.cm as cm
import numpy as np
import pyrender
from datasets import load_dataset
from plaid.containers.sample import Sample
from trimesh import Trimesh

from utils_inference import infer
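
# pyrender is only used for offscreen (headless) rendering; as the comment above notes, the
# PYOPENGL_PLATFORM variable must be set before pyrender is imported.
# Assumption about the deployment hardware: "egl" needs working EGL drivers, and on CPU-only
# hosts the "osmesa" software backend is the usual fallback.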

# Load the remeshed AirfRANS dataset and the data of the already-trained MMGP model.
hf_dataset = load_dataset("PLAID-datasets/AirfRANS_remeshed", split="all_samples")

with open('training_data.pkl', 'rb') as file:
    training_data = pickle.load(file)

train_ids = hf_dataset.description['split']['ML4PhySim_Challenge_train']
out_fields_names = hf_dataset.description['out_fields_names']
in_scalars_names = hf_dataset.description['in_scalars_names']
out_scalars_names = hf_dataset.description['out_scalars_names']

nb_samples = len(hf_dataset)
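
# `hf_dataset.description` is the PLAID metadata attached to the dataset; it provides the
# competition split ids and the names of the input/output fields and scalars used below.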

_HEADER_ = '''
<h2><b>MMGP demo on the <a href='https://huggingface.co/datasets/PLAID-datasets/AirfRANS_remeshed' target='_blank'>AirfRANS_remeshed dataset</a></b></h2>
'''

_HEADER_2 = '''
The model is already trained. The morphing is the same as the one used in the [MMGP paper](https://arxiv.org/abs/2305.12871),
but is much less involved than the one used in the winning entry of the [ML4PhySim competition](https://www.codabench.org/competitions/1534/).
The training set has 103 samples and is the one used in that competition (some evaluations are out-of-distribution).
Inference takes approximately 5 seconds and is done from scratch: no precomputation is reused when evaluating a sample,
so the morphing and the finite element interpolations are redone at each evaluation.
After choosing a sample id, change the field name in the dropdown menu to update the visualizations.
'''
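
# Rough, illustrative way to check the "approximately 5 seconds" figure quoted above
# (not wired into the app; `infer` is called with the same arguments as in compute_inference below):
#   import time; t0 = time.time(); infer(hf_dataset, 0, training_data); print(f"{time.time() - t0:.1f} s")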

def round_num(num) -> str:
    return '%s' % float('%.3g' % num)
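
# Example: round_num(0.0123456) == '0.0123' (three significant digits).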

def compute_inference(sample_id_str):
    sample_id = int(sample_id_str)

    # Rebuild the PLAID sample from its pickled payload and run the MMGP inference from scratch.
    sample_ = hf_dataset[sample_id]["sample"]
    plaid_sample = Sample.model_validate(pickle.loads(sample_))
    prediction = infer(hf_dataset, sample_id, training_data)
    reference = {fieldn: plaid_sample.get_field(fieldn) for fieldn in out_fields_names}

    # Pad 2D node coordinates with a zero z-coordinate so trimesh/pyrender can handle them.
    nodes = plaid_sample.get_nodes()
    if nodes.shape[1] == 2:
        nodes__ = np.zeros((nodes.shape[0], nodes.shape[1] + 1))
        nodes__[:, :-1] = nodes
        nodes = nodes__

    triangles = plaid_sample.get_elements()['TRI_3']
    trimesh = Trimesh(vertices=nodes, faces=triangles)

    # Cache the mesh, reference fields and prediction for the visualization callbacks.
    with open('computed_inference.pkl', 'wb') as file:
        pickle.dump([trimesh, reference, prediction], file)
    str__ = f"Sample {sample_id}"
    if sample_id in train_ids:
        str__ += " (in the training set)\n\n"
    else:
        str__ += " (not in the training set)\n\n"
    str__ += str(plaid_sample) + "\n"
    if len(in_scalars_names) > 0:
        str__ += "\nInput scalars:\n"
        for sname in in_scalars_names:
            str__ += f"- {sname}: {round_num(plaid_sample.get_scalar(sname))}\n"
    str__ += f"\nNumber of nodes in the mesh: {nodes.shape[0]}"
    return str__
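
# The three visualization callbacks below re-read computed_inference.pkl instead of re-running
# the inference, so switching fields in the dropdown only re-renders the cached mesh.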

def _load_cached_inference():
    # Read back the (mesh, reference fields, predicted fields) cached by compute_inference.
    with open('computed_inference.pkl', 'rb') as file:
        trimesh, reference, prediction = pickle.load(file)
    return trimesh, reference, prediction


def _render_field(trimesh, values, norm_values):
    # Color the mesh vertices with `values`, using a color scale normalized on `norm_values`,
    # and render the mesh offscreen with pyrender. Returns an RGB image array.
    norm = mpl.colors.Normalize(vmin=np.min(norm_values), vmax=np.max(norm_values))
    cmap = cm.seismic  # alternative: cm.coolwarm
    m = cm.ScalarMappable(norm=norm, cmap=cmap)
    trimesh.visual.vertex_colors = m.to_rgba(values)[:, :3]
    mesh = pyrender.Mesh.from_trimesh(trimesh, smooth=False)

    # compose scene
    scene = pyrender.Scene(ambient_light=[.1, .1, .3], bg_color=[0, 0, 0])
    camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
    light = pyrender.DirectionalLight(color=[1, 1, 1], intensity=1000.)
    scene.add(mesh, pose=np.eye(4))
    scene.add(light, pose=np.eye(4))
    scene.add(camera, pose=[[1, 0, 0, 1],
                            [0, 1, 0, 0],
                            [0, 0, 1, 6],
                            [0, 0, 0, 1]])

    # render scene
    r = pyrender.OffscreenRenderer(1024, 1024)
    color, _ = r.render(scene)
    r.delete()  # release the offscreen GL context between calls
    return color


def show_pred(fieldn):
    # MMGP prediction, displayed on the reference field's color scale for comparison.
    trimesh, reference, prediction = _load_cached_inference()
    return _render_field(trimesh, prediction[fieldn], reference[fieldn])


def show_ref(fieldn):
    # Reference field, on its own color scale.
    trimesh, reference, _ = _load_cached_inference()
    return _render_field(trimesh, reference[fieldn], reference[fieldn])


def show_err(fieldn):
    # Signed error (prediction - reference), also on the reference field's color scale.
    trimesh, reference, prediction = _load_cached_inference()
    return _render_field(trimesh, prediction[fieldn] - reference[fieldn], reference[fieldn])

if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.Markdown(_HEADER_)
        gr.Markdown(_HEADER_2)

        with gr.Row(variant="panel"):
            with gr.Column():
                d1 = gr.Slider(0, nb_samples - 1, value=0, step=1, label="Sample id",
                               info="Choose between 0 and " + str(nb_samples - 1))
                output4 = gr.Text(label="Information on sample")
                output5 = gr.Image(label="Error")
            with gr.Column():
                d2 = gr.Dropdown(out_fields_names, value=out_fields_names[0], label="Field name")
                output2 = gr.Image(label="Reference")
                output3 = gr.Image(label="MMGP prediction")

        # Re-run the inference when the sample id changes; re-render when the field changes.
        d1.input(compute_inference, [d1], [output4])
        d2.input(show_ref, [d2], [output2])
        d2.input(show_pred, [d2], [output3])
        d2.input(show_err, [d2], [output5])

    demo.launch()
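
# Assuming this script is the Space's app.py, it can also be run locally with `python app.py`;
# gradio then serves the interface at http://127.0.0.1:7860 by default.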