Skip to content

Commit

Permalink
Fixing style and type checking issues
Browse files Browse the repository at this point in the history
  • Loading branch information
pediejo committed Dec 13, 2024
1 parent f831339 commit 32e7a7f
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ def generate_minimal_code_distance_resources(
)
logical_architecture_resource_info.data_and_bus_code_distance = None
logical_architecture_resource_info.qec_cycle_allocation = None
logical_architecture_resource_info.logical_failure_rate_info = None

return logical_architecture_resource_info

Expand Down Expand Up @@ -155,7 +156,7 @@ def _generate_data_and_distillation_counts(
else:
if magic_state_factory is None:
raise ValueError(
"magic_state_factory cannot be None when the program requires T gates"
"magic_state_factory cannot be None when T gates are needed"
)

if optimization == "Space":
Expand Down Expand Up @@ -243,7 +244,7 @@ def get_qec_cycle_allocation(

if msf is None:
raise ValueError(
"magic_state_factory cannot be None when the program requires T gates"
"magic_state_factory cannot be None when T gates are needed"
)

distillation_time_in_cycles = msf.distillation_time_in_cycles
Expand Down
7 changes: 4 additions & 3 deletions src/benchq/magic_state_distillation/factory_selection.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Iterable
from typing import Iterable, Union
from ..resource_estimators.resource_info import MagicStateFactoryInfo
from decimal import Decimal


def find_optimal_factory(
per_t_gate_failure_tolerance: float,
per_t_gate_failure_tolerance: Union[float, Decimal],
magic_state_factory_iterator: Iterable[MagicStateFactoryInfo],
optimization: str = "Time",
) -> MagicStateFactoryInfo:
) -> Union[MagicStateFactoryInfo, None]:
"""Find the optimal factory from a given iterator of factories based on
the optimization criteria.
Args:
Expand Down
4 changes: 2 additions & 2 deletions src/benchq/resource_estimators/default_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

def get_precise_graph_estimate(
algorithm_implementation: AlgorithmImplementation,
logical_architecture_model: LogicalArchitectureModel,
logical_architecture_model: GraphBasedLogicalArchitectureModel,
hardware_model: BasicArchitectureModel,
optimization: str,
decoder_model: Optional[DecoderModel] = None,
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_precise_graph_estimate(

def get_fast_graph_estimate(
algorithm_implementation: AlgorithmImplementation,
logical_architecture_model: LogicalArchitectureModel,
logical_architecture_model: GraphBasedLogicalArchitectureModel,
hardware_model: BasicArchitectureModel,
optimization: str,
decoder_model: Optional[DecoderModel] = None,
Expand Down
32 changes: 22 additions & 10 deletions src/benchq/resource_estimators/graph_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,16 @@ def estimate_resources_from_compiled_implementation(
# If there are no T gates or rotation gates, then there is
# no need for a magic state factory
n_t_gates_per_rotation = 0
per_rotation_failure_tolerance = 0
per_t_gate_failure_tolerance = 0
per_rotation_failure_tolerance = Decimal(0)
per_t_gate_failure_tolerance = Decimal(0)
magic_state_factory = None
n_t_states = 0
else:
if compiled_implementation.program.n_rotation_gates == 0:
# If there are no rotation gates, then no T gates for
# rotations are needed
n_t_gates_per_rotation = 0
per_rotation_failure_tolerance = 0
per_rotation_failure_tolerance = Decimal(0)
n_t_states = compiled_implementation.program.n_t_gates
else:
per_rotation_failure_tolerance = Decimal(
Expand Down Expand Up @@ -150,9 +150,13 @@ def estimate_resources_from_compiled_implementation(

# Populate resource info

# Check that the qec_cycle_allocation was set
if log_arch_info.qec_cycle_allocation is None:
raise ValueError("log_arch_info.qec_cycle_allocation was not set.")

# Compute runtime to execute a single circuit
time_per_circuit_in_seconds = (
log_arch_info.qec_cycle_allocation.total # type: ignore
log_arch_info.qec_cycle_allocation.total
* hw_model.surface_code_cycle_time_in_seconds
)

Expand All @@ -170,29 +174,37 @@ def estimate_resources_from_compiled_implementation(

# Populate remaining logical failure rates

# Check that the logical failure rate info object was set
if log_arch_info.logical_failure_rate_info is None:
raise ValueError("log_arch_info.logical_failure_rate_info was not set.")

# Rotations
log_arch_info.logical_failure_rate_info.per_rotation_failure_rate = ( # type: ignore
log_arch_info.logical_failure_rate_info.per_rotation_failure_rate = (
per_rotation_failure_tolerance
)
log_arch_info.logical_failure_rate_info.total_rotation_failure_rate = ( # type: ignore
log_arch_info.logical_failure_rate_info.total_rotation_failure_rate = (
per_rotation_failure_tolerance
* compiled_implementation.program.n_rotation_gates
)

# Distillation
if magic_state_factory is None:
log_arch_info.logical_failure_rate_info.per_t_gate_failure_rate = 0.0 # type: ignore
log_arch_info.logical_failure_rate_info.total_distillation_failure_rate = ( # type: ignore
log_arch_info.logical_failure_rate_info.per_t_gate_failure_rate = 0.0
log_arch_info.logical_failure_rate_info.total_distillation_failure_rate = (
0.0
)
else:
log_arch_info.logical_failure_rate_info.per_t_gate_failure_rate = ( # type: ignore
log_arch_info.logical_failure_rate_info.per_t_gate_failure_rate = (
magic_state_factory.distilled_magic_state_error_rate
)
log_arch_info.logical_failure_rate_info.total_distillation_failure_rate = ( # type: ignore
log_arch_info.logical_failure_rate_info.total_distillation_failure_rate = (
magic_state_factory.distilled_magic_state_error_rate * n_t_states
)

# Check that the code distance was set
if log_arch_info.data_and_bus_code_distance is None:
raise ValueError("log_arch_info.data_and_bus_code_distance was not set.")

# Populate decoder resource info
decoder_info = get_decoder_info(
hw_model,
Expand Down
16 changes: 9 additions & 7 deletions src/benchq/resource_estimators/resource_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
"""Data structures describing estimated resources and related info."""

from dataclasses import dataclass, field
from typing import Generic, Optional, Tuple, TypeVar
from typing import Generic, Optional, Tuple, TypeVar, Union
from decimal import Decimal


from benchq.compilation.graph_states.compiled_data_structures import (
CompiledAlgorithmImplementation,
Expand Down Expand Up @@ -65,12 +67,12 @@ class DetailedIonTrapArchitectureResourceInfo:
class LogicalFailureRateInfo:
"""Logical failure rates for various processes."""

total_rotation_failure_rate: Optional[float] = None
total_distillation_failure_rate: Optional[float] = None
total_qec_failure_rate: Optional[float] = None
per_rotation_failure_rate: Optional[float] = None
per_t_gate_failure_rate: Optional[float] = None
per_qec_failure_rate: Optional[float] = None
total_rotation_failure_rate: Union[None, float, Decimal] = None
total_distillation_failure_rate: Union[None, float, Decimal] = None
total_qec_failure_rate: Union[None, float, Decimal] = None
per_rotation_failure_rate: Union[None, float, Decimal] = None
per_t_gate_failure_rate: Union[None, float, Decimal] = None
per_qec_failure_rate: Union[None, float, Decimal] = None

@property
def total_circuit_failure_rate(self) -> float:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def test_get_resource_estimations_for_program_accounts_for_spatial_resources(

magic_state_factory = MagicStateFactoryInfo(
"dummy_msf",
None,
None,
None,
None,
None, # type: ignore
None, # type: ignore
None, # type: ignore
None, # type: ignore
1,
)
data_and_bus_code_distance = None
Expand Down Expand Up @@ -204,9 +204,9 @@ def calculate_subroutine_sequence(x):

magic_state_factory = MagicStateFactoryInfo(
"dummy_msf",
None,
None,
None,
None, # type: ignore
None, # type: ignore
None, # type: ignore
distillation_time_in_cycles,
t_gates_per_distillation,
)
Expand Down
2 changes: 1 addition & 1 deletion tests/benchq/mlflow/test_data_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test__flatten_dict(input_dict, expected):
@patch("benchq.mlflow.data_logging.mlflow", autospec=True)
def test_log_input_objects_to_mlflow(mock_mlflow):
# Given
test_algo_descrip = AlgorithmImplementation(None, None, 10)
test_algo_descrip = AlgorithmImplementation(None, None, 10) # type: ignore

test_hardware_model = IONTrapModel(0.001, 0.1)

Expand Down
27 changes: 16 additions & 11 deletions tests/benchq/resource_estimators/test_graph_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,12 +186,14 @@ def test_get_resource_estimations_for_program_gives_correct_results(
architecture_model,
)

obtained_log_arch_info = resource_estimates.logical_architecture_resource_info

assert (
resource_estimates.logical_architecture_resource_info.data_and_bus_code_distance
obtained_log_arch_info.data_and_bus_code_distance # type: ignore
== expected_results["code_distance"]
)
assert (
resource_estimates.logical_architecture_resource_info.num_logical_qubits
obtained_log_arch_info.num_logical_qubits # type: ignore
== expected_results["n_logical_qubits"]
)

Expand Down Expand Up @@ -234,16 +236,16 @@ def test_better_hardware_architecture_does_not_require_more_resources(
high_noise_re = high_noise_resource_estimates.logical_architecture_resource_info

assert (
low_noise_resource_estimates.n_physical_qubits
<= high_noise_resource_estimates.n_physical_qubits
low_noise_resource_estimates.n_physical_qubits # type: ignore
<= high_noise_resource_estimates.n_physical_qubits # type: ignore
)
assert (
low_noise_re.data_and_bus_code_distance
<= high_noise_re.data_and_bus_code_distance
low_noise_re.data_and_bus_code_distance # type: ignore
<= high_noise_re.data_and_bus_code_distance # type: ignore
)
assert (
low_noise_resource_estimates.total_time_in_seconds
<= high_noise_resource_estimates.total_time_in_seconds
low_noise_resource_estimates.total_time_in_seconds # type: ignore
<= high_noise_resource_estimates.total_time_in_seconds # type: ignore
)


Expand Down Expand Up @@ -298,12 +300,15 @@ def test_higher_error_budget_does_not_require_more_resources(
low_re = low_error_resource_estimates.logical_architecture_resource_info

assert (
high_error_resource_estimates.n_physical_qubits
high_error_resource_estimates.n_physical_qubits # type: ignore
<= low_error_resource_estimates.n_physical_qubits
)
assert high_re.data_and_bus_code_distance <= low_re.data_and_bus_code_distance

high_err_distance = high_re.data_and_bus_code_distance # type: ignore
low_err_distance = low_re.data_and_bus_code_distance # type: ignore
assert high_err_distance <= low_err_distance # type: ignore
assert (
high_error_resource_estimates.total_time_in_seconds
high_error_resource_estimates.total_time_in_seconds # type: ignore
<= low_error_resource_estimates.total_time_in_seconds
)

Expand Down

0 comments on commit 32e7a7f

Please sign in to comment.