149 lines
7.2 KiB
Solidity
149 lines
7.2 KiB
Solidity
// SPDX-License-Identifier: MIT
|
|
pragma solidity ^0.7.2;
|
|
pragma experimental ABIEncoderV2;
|
|
|
|
/******************************************************************************\
* Author: Nick Mudge
*
* Library with the diamond storage layout and the internal implementation of
* diamondCut.
/******************************************************************************/
|
|
import "./IDiamondCut.sol";
|
|
|
|
library LibDiamond {
    // Storage slot used for the DiamondStorage struct. diamondStorage() points the
    // struct at this slot explicitly instead of relying on compiler-assigned slots.
    bytes32 constant DIAMOND_STORAGE_POSITION = keccak256("diamond.standard.diamond.storage");

    // Per-selector record: which facet implements the selector, and the index of
    // the selector inside DiamondStorage.selectors.
    struct FacetAddressAndSelectorPosition {
        address facetAddress;
        uint16 selectorPosition;
    }

    struct DiamondStorage {
        // function selector => facet address and selector position in selectors array
        mapping(bytes4 => FacetAddressAndSelectorPosition) facetAddressAndSelectorPosition;
        // every selector currently registered; indexed by selectorPosition
        bytes4[] selectors;
        // not read or written by this library; presumably ERC-165 interface ids
        // maintained elsewhere -- TODO confirm against the facets that use it
        mapping(bytes4 => bool) supportedInterfaces;
        // owner of the contract
        address contractOwner;
    }

    // Returns a storage pointer to the DiamondStorage struct located at
    // DIAMOND_STORAGE_POSITION.
    function diamondStorage() internal pure returns (DiamondStorage storage ds) {
        bytes32 position = DIAMOND_STORAGE_POSITION;
        assembly {
            ds.slot := position
        }
    }

    event DiamondCut(IDiamondCut.FacetCut[] diamondCut, address init, bytes _calldata);

    // Internal function version of diamondCut.
    // This code is almost the same as the external diamondCut, except it takes
    // 'FacetCut[] memory _diamondCut' instead of 'FacetCut[] calldata _diamondCut'.
    // The code is duplicated to prevent copying calldata to memory, which causes
    // an error for a two dimensional array.
    //
    // Applies every cut in order, emits DiamondCut, then runs the optional
    // initialization call described by _init and _calldata.
    function diamondCut(
        IDiamondCut.FacetCut[] memory _diamondCut,
        address _init,
        bytes memory _calldata
    ) internal {
        // The running selector count is threaded through the helper so the array
        // length is read from storage only once.
        uint256 selectorCount = diamondStorage().selectors.length;
        for (uint256 facetIndex; facetIndex < _diamondCut.length; facetIndex++) {
            selectorCount = addReplaceRemoveFacetSelectors(
                selectorCount,
                _diamondCut[facetIndex].facetAddress,
                _diamondCut[facetIndex].action,
                _diamondCut[facetIndex].functionSelectors
            );
        }
        emit DiamondCut(_diamondCut, _init, _calldata);
        initializeDiamondCut(_init, _calldata);
    }

    // Applies one cut (Add, Replace or Remove) for the given selectors.
    //
    // _selectorCount   number of selectors registered before this cut
    // _newFacetAddress facet to add/replace with; must be address(0) for Remove
    // _action          which operation to perform
    // _selectors       selectors affected; must not be empty
    //
    // Returns the number of selectors registered after the cut.
    // Reverts with a "LibDiamondCut: ..." reason when a precondition is violated.
    function addReplaceRemoveFacetSelectors(
        uint256 _selectorCount,
        address _newFacetAddress,
        IDiamondCut.FacetCutAction _action,
        bytes4[] memory _selectors
    ) internal returns (uint256) {
        DiamondStorage storage ds = diamondStorage();
        require(_selectors.length > 0, "LibDiamondCut: No selectors in facet to cut");
        if (_action == IDiamondCut.FacetCutAction.Add) {
            require(_newFacetAddress != address(0), "LibDiamondCut: Add facet can't be address(0)");
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Add facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
                require(oldFacetAddress == address(0), "LibDiamondCut: Can't add function that already exists");
                // selectorPosition is a uint16. Without this check the cast below
                // would silently truncate and the stored position would no longer
                // match the selector's index in ds.selectors.
                require(_selectorCount <= type(uint16).max, "LibDiamondCut: Too many selectors");
                ds.facetAddressAndSelectorPosition[selector] = FacetAddressAndSelectorPosition(
                    _newFacetAddress,
                    uint16(_selectorCount)
                );
                ds.selectors.push(selector);
                _selectorCount++;
            }
        } else if (_action == IDiamondCut.FacetCutAction.Replace) {
            require(_newFacetAddress != address(0), "LibDiamondCut: Replace facet can't be address(0)");
            enforceHasContractCode(_newFacetAddress, "LibDiamondCut: Replace facet has no code");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                address oldFacetAddress = ds.facetAddressAndSelectorPosition[selector].facetAddress;
                // only useful if immutable functions exist
                require(oldFacetAddress != address(this), "LibDiamondCut: Can't replace immutable function");
                require(oldFacetAddress != _newFacetAddress, "LibDiamondCut: Can't replace function with same function");
                require(oldFacetAddress != address(0), "LibDiamondCut: Can't replace function that doesn't exist");
                // replace old facet address; the selector keeps its position
                ds.facetAddressAndSelectorPosition[selector].facetAddress = _newFacetAddress;
            }
        } else if (_action == IDiamondCut.FacetCutAction.Remove) {
            require(_newFacetAddress == address(0), "LibDiamondCut: Remove facet address must be address(0)");
            for (uint256 selectorIndex; selectorIndex < _selectors.length; selectorIndex++) {
                bytes4 selector = _selectors[selectorIndex];
                FacetAddressAndSelectorPosition memory oldFacetAddressAndSelectorPosition = ds.facetAddressAndSelectorPosition[selector];
                require(oldFacetAddressAndSelectorPosition.facetAddress != address(0), "LibDiamondCut: Can't remove function that doesn't exist");
                // only useful if immutable functions exist
                require(oldFacetAddressAndSelectorPosition.facetAddress != address(this), "LibDiamondCut: Can't remove immutable function.");
                // Swap-and-pop: move the last selector into the freed position
                // (unless the removed selector already is the last one) ...
                if (oldFacetAddressAndSelectorPosition.selectorPosition != _selectorCount - 1) {
                    bytes4 lastSelector = ds.selectors[_selectorCount - 1];
                    ds.selectors[oldFacetAddressAndSelectorPosition.selectorPosition] = lastSelector;
                    ds.facetAddressAndSelectorPosition[lastSelector].selectorPosition = oldFacetAddressAndSelectorPosition.selectorPosition;
                }
                // ... then drop the last array element and the removed selector's record.
                ds.selectors.pop();
                delete ds.facetAddressAndSelectorPosition[selector];
                _selectorCount--;
            }
        } else {
            revert("LibDiamondCut: Incorrect FacetCutAction");
        }
        return _selectorCount;
    }

    // Runs the optional initialization step of a diamond cut.
    //
    // When _init is address(0), _calldata must be empty and nothing is executed.
    // Otherwise _calldata must be non-empty and is executed with delegatecall on
    // _init. If that call fails, its revert data is re-raised unchanged; if it
    // failed without revert data, a generic reason is used.
    function initializeDiamondCut(address _init, bytes memory _calldata) internal {
        if (_init == address(0)) {
            // NOTE(review): the reason string lacks a space ("but_calldata"); it is
            // kept verbatim because tooling may match on the exact text.
            require(_calldata.length == 0, "LibDiamondCut: _init is address(0) but_calldata is not empty");
        } else {
            require(_calldata.length > 0, "LibDiamondCut: _calldata is empty but _init is not address(0)");
            if (_init != address(this)) {
                enforceHasContractCode(_init, "LibDiamondCut: _init address has no code");
            }
            (bool success, bytes memory revertData) = _init.delegatecall(_calldata);
            if (!success) {
                if (revertData.length > 0) {
                    // Bubble up the callee's revert data as-is. Passing it through
                    // revert(string(...)) would wrap the already ABI-encoded payload
                    // in a second Error(string), hiding the original reason.
                    assembly {
                        revert(add(revertData, 32), mload(revertData))
                    }
                } else {
                    revert("LibDiamondCut: _init function reverted");
                }
            }
        }
    }

    // Reverts with _errorMessage unless _contract has code deployed at the time
    // of the call (extcodesize > 0).
    function enforceHasContractCode(address _contract, string memory _errorMessage) internal view {
        uint256 contractSize;
        assembly {
            contractSize := extcodesize(_contract)
        }
        require(contractSize > 0, _errorMessage);
    }
}