// SPDX-License-Identifier: MIT
pragma solidity ^0.7.2;
pragma experimental ABIEncoderV2;

/******************************************************************************\
* Author: Nick Mudge
*
* Implementation of Diamond facet.
/******************************************************************************/

import "./IDiamondCut.sol";

/// @title LibDiamond
/// @notice Internal helpers that keep the selector => facet routing table of a
///         diamond in a fixed storage slot and apply add/replace/remove cuts to it.
library LibDiamond {
    // Storage slot under which the DiamondStorage struct is anchored.
    bytes32 constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage");

    struct FacetAddressAndSelectorPosition {
        address facetAddress;
        // index of the selector inside DiamondStorage.selectors
        uint16 selectorPosition;
    }

    struct DiamondStorage {
        // function selector => facet address and selector position in selectors array
        mapping(bytes4 => FacetAddressAndSelectorPosition) facetAddressAndSelectorPosition;
        bytes4[] selectors;
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    /// @notice Returns a storage pointer to the DiamondStorage struct.
    /// @return ds Storage reference whose slot is DIAMOND_STORAGE_POSITION.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        assembly {
            ds.slot := position
        }
    }

    event DiamondCut(IDiamondCut.FacetCut[] diamondCut, address init, bytes _calldata);

    // Internal function version of diamondCut
    // This code is almost the same as the external diamondCut,
    // except it is using 'FacetCut[] memory _diamondCut' instead of
    // 'FacetCut[] calldata _diamondCut'.
    // The code is duplicated to prevent copying calldata to memory which
    // causes an error for a two dimensional array.

    /// @notice Applies every cut in `_diamondCut`, emits DiamondCut, then runs the
    ///         optional initializer.
    /// @param _diamondCut Facet addresses, actions and selectors to apply, in order.
    /// @param _init Address to delegatecall after the cuts, or address(0) for none.
    /// @param _calldata Calldata passed to `_init`; must be empty when `_init` is address(0).
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        // The running selector count is threaded through the helper so the
        // array length is read from storage only once.
        uint256 selectorCount = diamondStorage().selectors.length;
        for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
            selectorCount = addReplaceRemoveFacetSelectors(
                selectorCount,
                _diamondCut[facetIndex].facetAddress,
                _diamondCut[facetIndex].action,
                _diamondCut[facetIndex].functionSelectors
            );
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    /// @notice Applies a single Add, Replace or Remove action for a list of selectors.
    /// @param _selectorCount Number of selectors currently stored in `ds.selectors`.
    /// @param _newFacetAddress Facet to route to (Add/Replace); must be address(0) for Remove.
    /// @param _action Which FacetCutAction to perform.
    /// @param _selectors Function selectors affected by the action; must be non-empty.
    /// @return The selector count after the action has been applied.
    function addReplaceRemoveFacetSelectors(
        uint256 _selectorCount,
        address _newFacetAddress,
        IDiamondCut.FacetCutAction _action,
        bytes4[] memory _selectors
    ) internal returns (uint256) {
        DiamondStorage storage ds = diamondStorage();
        require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        if (_action == IDiamondCut.FacetCutAction.Add) {
            require(_newFacetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Add facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
                require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
                // NOTE(review): the uint16 cast truncates silently once the count
                // exceeds 65535; no guard exists here.
                ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(
                    _newFacetAddress,
                    uint16(_selectorCount)
                );
                ds.selectors.push(selector);
                _selectorCount++;
            }
        } else if (_action == IDiamondCut.FacetCutAction.Replace) {
            require(_newFacetAddress != address(0), "LibDiamondCut: Replace facet can't be address(0)");
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Replace facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
                // only useful if immutable functions exist
                require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
                require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function");
                require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
                // replace old facet address
                ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress;
            }
        } else if (_action == IDiamondCut.FacetCutAction.Remove) {
            require(_newFacetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition =
                    ds.facetAddressAndSelectorPosition[selector];
                require(
                    oldFacetAddressAndSelectorPosition.facetAddress != address(0),
                    "LibDiamondCut: Can't remove function that doesn't exist"
                );
                // only useful if immutable functions exist
                require(
                    oldFacetAddressAndSelectorPosition.facetAddress != address(this),
                    "LibDiamondCut: Can't remove immutable function."
                );
                // replace selector with last selector
                if (oldFacetAddressAndSelectorPosition.selectorPosition != _selectorCount - 1) {
                    bytes4 lastSelector = ds.selectors[_selectorCount - 1];
                    ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
                    ds.facetAddressAndSelectorPosition[lastSelector].selectorPosition =
                        oldFacetAddressAndSelectorPosition.selectorPosition;
                }
                // delete last selector
                ds.selectors.pop();
                delete ds.facetAddressAndSelectorPosition[selector];
                _selectorCount--;
            }
        } else {
            revert("LibDiamondCut: Incorrect FacetCutAction");
        }
        return _selectorCount;
    }

    /// @notice Delegatecalls `_init` with `_calldata` after a cut, if an initializer was given.
    /// @dev When the delegatecall fails with return data, that data is re-raised
    ///      unchanged so the caller sees the initializer's own revert payload.
    /// @param _init Address to delegatecall, or address(0) to skip initialization.
    /// @param _calldata Calldata for the delegatecall; must be empty iff `_init` is address(0).
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty");
        } else {
            require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)");
            if (_init != address(this)) {
                enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
            }
            (bool success, bytes memory returnData) = _init.delegatecall(_calldata);
            if (!success) {
                if (returnData.length > 0) {
                    // Bubble up the error. The return data is already ABI-encoded
                    // revert data, so it is forwarded raw; wrapping it with
                    // revert(string(...)) would nest it inside a second
                    // Error(string) and garble the reason.
                    assembly {
                        // skip the 32-byte length word of the bytes array
                        revert(add(returnData, 32), mload(returnData))
                    }
                } else {
                    revert("LibDiamondCut: _init function reverted");
                }
            }
        }
    }

    /// @notice Reverts with `_errorMessage` unless `_contract` has deployed code.
    /// @param _contract Address whose code size is checked via extcodesize.
    /// @param _errorMessage Revert reason used when the code size is zero.
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 contractSize;
        assembly {
            contractSize := extcodesize(_contract)
        }
        require(contractSize > 0, _errorMessage);
    }
}