# Source code for qnexus.client.jobs._compile

"""Client API for compilation in Nexus."""

from typing import Union, cast

from pytket.backends.status import StatusEnum
from quantinuum_schemas.models.hypertket_config import HyperTketConfig

import qnexus.exceptions as qnx_exc
from qnexus.client import circuits as circuit_api
from qnexus.client import get_nexus_client
from qnexus.client.utils import accept_circuits_for_programs
from qnexus.context import get_active_project, merge_properties_from_context
from qnexus.models import BackendConfig
from qnexus.models.annotations import Annotations, CreateAnnotations, PropertiesDict
from qnexus.models.references import (
    CircuitRef,
    CompilationPassRef,
    CompilationResultRef,
    CompileJobRef,
    DataframableList,
    JobType,
    ProjectRef,
)


@accept_circuits_for_programs
@merge_properties_from_context
def start_compile_job(
    programs: Union[CircuitRef, list[CircuitRef]],
    backend_config: BackendConfig,
    name: str,
    description: str = "",
    project: ProjectRef | None = None,
    properties: PropertiesDict | None = None,
    optimisation_level: int = 2,
    credential_name: str | None = None,
    user_group: str | None = None,
    hypertket_config: HyperTketConfig | None = None,
    skip_intermediate_circuits: bool = True,
) -> CompileJobRef:
    """Submit a compile job to be run in Nexus.

    The job is POSTed to ``/api/jobs/v1beta3`` with one definition item per
    circuit in ``programs``.

    Args:
        programs: A single CircuitRef or a list of CircuitRefs to compile.
            (The ``accept_circuits_for_programs`` decorator presumably also
            lets callers pass these under a ``circuits`` name — TODO confirm.)
        backend_config: Backend configuration, serialized via ``model_dump()``.
        name: Name annotation for the new job.
        description: Description annotation for the new job.
        project: Project to create the job in; when falsy, the active project
            from context is used (and is required).
        properties: Property annotations for the job. (Presumably merged with
            context properties by ``merge_properties_from_context`` — verify.)
        optimisation_level: Forwarded as ``optimisation_level`` in the job
            definition.
        credential_name: Forwarded as ``credential_name`` in the definition.
        user_group: Forwarded as ``user_group`` in the definition.
        hypertket_config: Optional HyperTket configuration; sent as ``None``
            when not given.
        skip_intermediate_circuits: Sent as ``skip_store_intermediate_passes``.

    Returns:
        A CompileJobRef whose status is recorded locally as SUBMITTED.

    Raises:
        TypeError: If ``programs`` is not a CircuitRef or a list of CircuitRefs.
        qnexus.exceptions.ResourceCreateFailed: If Nexus does not answer 202.
    """
    # Resolve the project before validating inputs (keeps the original order).
    target_project = cast(
        ProjectRef, project or get_active_project(project_required=True)
    )

    if isinstance(programs, CircuitRef):
        circuit_refs = [programs]
    elif isinstance(programs, list):
        if any(not isinstance(ref, CircuitRef) for ref in programs):
            raise TypeError("Compile jobs only accept circuits")
        circuit_refs = programs
    else:
        raise TypeError(
            "Expected programs to be either a CircuitRef or a list of CircuitRefs"
        )
    program_ids = [str(ref.id) for ref in circuit_refs]

    annotations = CreateAnnotations(
        name=name,
        description=description,
        properties=properties,
    ).model_dump(exclude_none=True)

    job_definition = {
        "job_definition_type": "compile_job_definition",
        "backend_config": backend_config.model_dump(),
        "user_group": user_group,
        "hypertket_config": (
            None if hypertket_config is None else hypertket_config.model_dump()
        ),
        "optimisation_level": optimisation_level,
        "credential_name": credential_name,
        "items": [{"program_id": program_id} for program_id in program_ids],
        "skip_store_intermediate_passes": skip_intermediate_circuits,
    }
    attributes = {**annotations, "job_type": "compile", "definition": job_definition}

    payload = {
        "data": {
            "attributes": attributes,
            "relationships": {
                "project": {
                    "data": {"id": str(target_project.id), "type": "project"}
                },
            },
            "type": "job",
        }
    }

    response = get_nexus_client().post("/api/jobs/v1beta3", json=payload)
    if response.status_code != 202:
        raise qnx_exc.ResourceCreateFailed(
            message=response.text, status_code=response.status_code
        )

    job_data = response.json()["data"]
    return CompileJobRef(
        id=job_data["id"],
        annotations=Annotations.from_dict(job_data["attributes"]),
        job_type=JobType.COMPILE,
        last_status=StatusEnum.SUBMITTED,
        last_message="",
        project=target_project,
        backend_config_store=backend_config,
    )
def _get_compilation_record(compilation_id: object) -> dict:
    """Fetch the raw JSON body of one compilation record from Nexus.

    Raises:
        qnexus.exceptions.ResourceFetchFailed: If Nexus does not answer 200.
    """
    resp = get_nexus_client().get(f"/api/compilations/v1beta2/{compilation_id}")
    if resp.status_code != 200:
        raise qnx_exc.ResourceFetchFailed(
            message=resp.text, status_code=resp.status_code
        )
    return resp.json()


def _find_included(included: list[dict], resource_id: str, description: str) -> dict:
    """Return the entry of a response's ``included`` list whose ``id`` matches.

    Replaces a bare ``next(...)`` so that a missing entry surfaces as the
    module's usual fetch error rather than a leaked ``StopIteration``.

    Raises:
        qnexus.exceptions.ResourceFetchFailed: If no entry has ``resource_id``.
    """
    for entry in included:
        if entry["id"] == resource_id:
            return entry
    raise qnx_exc.ResourceFetchFailed(
        message=f"Response did not include the {description} with id {resource_id}"
    )


def _project_ref_from_response(res_dict: dict) -> ProjectRef:
    """Build the ProjectRef referenced by a compilation record response.

    Reads the project id from ``data.relationships.project`` and its details
    from the matching ``included`` entry.
    """
    project_id = res_dict["data"]["relationships"]["project"]["data"]["id"]
    project_details = _find_included(
        res_dict.get("included", []), project_id, "project"
    )
    project_attributes = project_details["attributes"]
    return ProjectRef(
        id=project_id,
        annotations=Annotations.from_dict(project_attributes),
        contents_modified=project_attributes["contents_modified"],
        archived=project_attributes["archived"],
    )


def _circuit_ref_from_relationship(
    res_dict: dict, relationship: str, project: ProjectRef
) -> CircuitRef:
    """Build a CircuitRef for the circuit named by ``relationship`` in a record.

    ``relationship`` is a key under ``data.relationships`` (for example
    ``"compiled_circuit"`` or ``"original_circuit"``); the circuit's
    annotations come from the matching ``included`` entry.
    """
    circuit_id = res_dict["data"]["relationships"][relationship]["data"]["id"]
    circuit_details = _find_included(
        res_dict.get("included", []), circuit_id, relationship
    )
    return CircuitRef(
        id=circuit_id,
        annotations=Annotations.from_dict(circuit_details["attributes"]),
        project=project,
    )


def _results(
    compile_job: CompileJobRef,
    allow_incomplete: bool = False,
) -> DataframableList[CompilationResultRef]:
    """Get the results from a compile job.

    Args:
        compile_job: The compile job whose results should be fetched.
        allow_incomplete: When False, raise unless the job status is
            ``COMPLETED``. When True, return results for whichever items
            have individually completed.

    Returns:
        One CompilationResultRef per completed item of the job definition.

    Raises:
        qnexus.exceptions.ResourceFetchFailed: If a request fails, the job is
            incomplete and ``allow_incomplete`` is False, or a response lacks
            the referenced project.
    """
    resp = get_nexus_client().get(f"/api/jobs/v1beta3/{compile_job.id}")
    if resp.status_code != 200:
        raise qnx_exc.ResourceFetchFailed(
            message=resp.text, status_code=resp.status_code
        )
    resp_data = resp.json()["data"]

    job_status = resp_data["attributes"]["status"]["status"]
    if job_status != "COMPLETED" and not allow_incomplete:
        raise qnx_exc.ResourceFetchFailed(message=f"Job status: {job_status}")

    # Only completed items are guaranteed to carry a compilation_id.
    compilation_ids = [
        item["compilation_id"]
        for item in resp_data["attributes"]["definition"]["items"]
        if item["status"]["status"] == "COMPLETED"
    ]

    compilation_refs: DataframableList[CompilationResultRef] = DataframableList([])
    for compilation_id in compilation_ids:
        comp_json = _get_compilation_record(compilation_id)
        project = _project_ref_from_response(comp_json)
        compilation_refs.append(
            CompilationResultRef(
                id=comp_json["data"]["id"],
                annotations=Annotations.from_dict(comp_json["data"]["attributes"]),
                project=project,
            )
        )
    return compilation_refs


def _fetch_compilation_output(
    compilation_result_ref: CompilationResultRef,
) -> tuple[CircuitRef, CircuitRef]:
    """Get the input/output compiled circuit from a compilation job.

    Returns:
        ``(input_circuit_ref, compiled_circuit_ref)``, both attached to the
        compilation's project.

    Raises:
        qnexus.exceptions.ResourceFetchFailed: If the request fails or the
            response does not include the referenced project or circuits.
    """
    res_dict = _get_compilation_record(compilation_result_ref.id)
    project = _project_ref_from_response(res_dict)
    compiled_circuit_ref = _circuit_ref_from_relationship(
        res_dict, "compiled_circuit", project
    )
    input_circuit_ref = _circuit_ref_from_relationship(
        res_dict, "original_circuit", project
    )
    return input_circuit_ref, compiled_circuit_ref


def _fetch_compilation_passes(
    compilation_result_ref: CompilationResultRef,
) -> DataframableList[CompilationPassRef]:
    """Get summary information on the passes from a compile job.

    Lists the passes filtered by the compilation's id, then fetches each
    pass's input and output circuit individually.

    Raises:
        qnexus.exceptions.ResourceFetchFailed: If the pass listing request
            does not answer 200.
    """
    params = {"filter[compilation][id]": str(compilation_result_ref.id)}
    resp = get_nexus_client().get("/api/compilation_passes/v1beta2", params=params)
    if resp.status_code != 200:
        raise qnx_exc.ResourceFetchFailed(
            message=resp.text, status_code=resp.status_code
        )

    pass_list: DataframableList[CompilationPassRef] = DataframableList([])
    for pass_info in resp.json()["data"]:
        pass_name = pass_info["attributes"]["pass_name"]
        relationships = pass_info["relationships"]
        # One fetch per circuit: the listing only carries circuit ids.
        pass_input_circuit = circuit_api._fetch_by_id(
            relationships["original_circuit"]["data"]["id"],
            scope=None,
        )
        pass_output_circuit = circuit_api._fetch_by_id(
            relationships["compiled_circuit"]["data"]["id"],
            scope=None,
        )
        pass_list.append(
            CompilationPassRef(
                pass_name=pass_name,
                input_circuit=pass_input_circuit,
                output_circuit=pass_output_circuit,
                id=pass_info["id"],
            )
        )
    return pass_list