Skip to content

Commit

Permalink
Merge pull request #592 from spectrocloud/NSGUpdateIssue
Browse files Browse the repository at this point in the history
Update NSG only if default rules are not present, or else skip the update
  • Loading branch information
k8s-ci-robot authored May 12, 2020
2 parents 2f3e9dd + 5d80ea5 commit a2760be
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 49 deletions.
18 changes: 16 additions & 2 deletions cloud/services/securitygroups/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package securitygroups

import (
"context"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-06-01/network"
"github.com/Azure/go-autorest/autorest"
azure "sigs.k8s.io/cluster-api-provider-azure/cloud"
Expand Down Expand Up @@ -59,10 +58,25 @@ func (ac *AzureClient) Get(ctx context.Context, resourceGroupName, sgName string

// CreateOrUpdate creates or updates a network security group in the specified resource group.
func (ac *AzureClient) CreateOrUpdate(ctx context.Context, resourceGroupName string, sgName string, sg network.SecurityGroup) error {
future, err := ac.securitygroups.CreateOrUpdate(ctx, resourceGroupName, sgName, sg)
var etag string
if sg.Etag != nil {
etag = *sg.Etag
}
req, err := ac.securitygroups.CreateOrUpdatePreparer(ctx, resourceGroupName, sgName, sg)
if err != nil {
err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", nil, "Failure preparing request")
return err
}
if etag != "" {
req.Header.Add("If-Match", etag)
}

future, err := ac.securitygroups.CreateOrUpdateSender(req)
if err != nil {
err = autorest.NewErrorWithError(err, "network.SecurityGroupsClient", "CreateOrUpdate", future.Response(), "Failure sending request")
return err
}

err = future.WaitForCompletionRef(ctx, ac.securitygroups.Client)
if err != nil {
return err
Expand Down
132 changes: 92 additions & 40 deletions cloud/services/securitygroups/securitygroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package securitygroups
import (
"context"
"strconv"
"strings"

"github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-06-01/network"
"github.com/Azure/go-autorest/autorest/to"
Expand All @@ -27,6 +28,11 @@ import (
azure "sigs.k8s.io/cluster-api-provider-azure/cloud"
)

const (
apiServerRule = "apiServerRule"
sshRule = "sshRule"
)

// Spec specification for network security groups
type Spec struct {
Name string
Expand Down Expand Up @@ -59,52 +65,59 @@ func (s *Service) Reconcile(ctx context.Context, spec interface{}) error {
return errors.New("invalid security groups specification")
}

securityRules := &[]network.SecurityRule{}
securityGroup, err := s.Client.Get(ctx, s.Scope.ResourceGroup(), nsgSpec.Name)
if err != nil && !azure.ResourceNotFound(err) {
return errors.Wrapf(err, "failed to get NSG %s in %s", nsgSpec.Name, s.Scope.ResourceGroup())
}

nsgExists := false
securityRules := make([]network.SecurityRule, 0)
if securityGroup.Name != nil {
nsgExists = true
securityRules = *securityGroup.SecurityRules
}

defaultRules := make(map[string]network.SecurityRule, 0)
defaultRules[sshRule] = getRule("allow_ssh", "22", 100)
defaultRules[apiServerRule] = getRule("allow_6443", strconv.Itoa(int(s.Scope.APIServerPort())), 101)

if nsgSpec.IsControlPlane {
klog.V(2).Infof("using additional rules for control plane %s", nsgSpec.Name)
securityRules = &[]network.SecurityRule{
{
Name: to.StringPtr("allow_ssh"),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocolTCP,
SourceAddressPrefix: to.StringPtr("*"),
SourcePortRange: to.StringPtr("*"),
DestinationAddressPrefix: to.StringPtr("*"),
DestinationPortRange: to.StringPtr("22"),
Access: network.SecurityRuleAccessAllow,
Direction: network.SecurityRuleDirectionInbound,
Priority: to.Int32Ptr(100),
},
},
{
Name: to.StringPtr("allow_6443"),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocolTCP,
SourceAddressPrefix: to.StringPtr("*"),
SourcePortRange: to.StringPtr("*"),
DestinationAddressPrefix: to.StringPtr("*"),
DestinationPortRange: to.StringPtr(strconv.Itoa(int(s.Scope.APIServerPort()))),
Access: network.SecurityRuleAccessAllow,
Direction: network.SecurityRuleDirectionInbound,
Priority: to.Int32Ptr(101),
},
},
if nsgExists {
// Check if the expected rules are present
update := false
for _, rule := range defaultRules {
if !ruleExists(securityRules, rule) {
update = true
securityRules = append(securityRules, rule)
}
}
if !update {
// Skip update for control-plane NSG as the required default rules are present
klog.V(2).Infof("security group %s exists and no default rules are missing, skipping update", nsgSpec.Name)
return nil
}
} else {
klog.V(2).Infof("applying missing default rules for control plane NSG %s", nsgSpec.Name)
securityRules = append(securityRules, defaultRules[sshRule], defaultRules[apiServerRule])
}
} else if nsgExists {
// Skip update for node NSG as no default rules are required
klog.V(2).Infof("security group %s exists and no default rules are required, skipping update", nsgSpec.Name)
return nil
}

klog.V(2).Infof("creating security group %s", nsgSpec.Name)
err := s.Client.CreateOrUpdate(
ctx,
s.Scope.ResourceGroup(),
nsgSpec.Name,
network.SecurityGroup{
Location: to.StringPtr(s.Scope.Location()),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
SecurityRules: securityRules,
},
sg := network.SecurityGroup{
Location: to.StringPtr(s.Scope.Location()),
SecurityGroupPropertiesFormat: &network.SecurityGroupPropertiesFormat{
SecurityRules: &securityRules,
},
)
}
if nsgExists {
// We append the existing NSG etag to the header to ensure we only apply the updates if the NSG has not been modified.
sg.Etag = securityGroup.Etag
}
klog.V(2).Infof("creating security group %s", nsgSpec.Name)
err = s.Client.CreateOrUpdate(ctx, s.Scope.ResourceGroup(), nsgSpec.Name, sg)
if err != nil {
return errors.Wrapf(err, "failed to create security group %s in resource group %s", nsgSpec.Name, s.Scope.ResourceGroup())
}
Expand All @@ -113,6 +126,45 @@ func (s *Service) Reconcile(ctx context.Context, spec interface{}) error {
return err
}

func ruleExists(rules []network.SecurityRule, rule network.SecurityRule) bool {
for _, existingRule := range rules {
if !strings.EqualFold(to.String(existingRule.Name), to.String(rule.Name)) {
continue
}
if !strings.EqualFold(to.String(existingRule.DestinationPortRange), to.String(rule.DestinationPortRange)) {
continue
}
if existingRule.Protocol != network.SecurityRuleProtocolTCP &&
existingRule.Access != network.SecurityRuleAccessAllow &&
existingRule.Direction != network.SecurityRuleDirectionInbound {
continue
}
if !strings.EqualFold(to.String(existingRule.SourcePortRange), "*") &&
!strings.EqualFold(to.String(existingRule.SourceAddressPrefix), "*") &&
!strings.EqualFold(to.String(existingRule.DestinationAddressPrefix), "*") {
continue
}
return true
}
return false
}

func getRule(name, destinationPort string, priority int32) network.SecurityRule {
return network.SecurityRule{
Name: to.StringPtr(name),
SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{
Protocol: network.SecurityRuleProtocolTCP,
SourceAddressPrefix: to.StringPtr("*"),
SourcePortRange: to.StringPtr("*"),
DestinationAddressPrefix: to.StringPtr("*"),
DestinationPortRange: to.StringPtr(destinationPort),
Access: network.SecurityRuleAccessAllow,
Direction: network.SecurityRuleDirectionInbound,
Priority: to.Int32Ptr(priority),
},
}
}

// Delete deletes the network security group with the provided name.
func (s *Service) Delete(ctx context.Context, spec interface{}) error {
nsgSpec, ok := spec.(*Spec)
Expand Down
16 changes: 9 additions & 7 deletions cloud/services/securitygroups/securitygroups_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,32 @@ func TestReconcileSecurityGroups(t *testing.T) {
sgName string
isControlPlane bool
vnetSpec *infrav1.VnetSpec
expect func(m *mock_securitygroups.MockClientMockRecorder)
expect func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder)
}{
{
name: "security group does not exists",
sgName: "my-sg",
isControlPlane: true,
vnetSpec: &infrav1.VnetSpec{},
expect: func(m *mock_securitygroups.MockClientMockRecorder) {
m.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
expect: func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder) {
m.Get(context.TODO(), "my-rg", "my-sg")
m1.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
},
}, {
name: "security group does not exist and it's not for a control plane",
sgName: "my-sg",
isControlPlane: false,
vnetSpec: &infrav1.VnetSpec{},
expect: func(m *mock_securitygroups.MockClientMockRecorder) {
m.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
expect: func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder) {
m.Get(context.TODO(), "my-rg", "my-sg")
m1.CreateOrUpdate(context.TODO(), "my-rg", "my-sg", gomock.AssignableToTypeOf(network.SecurityGroup{}))
},
}, {
name: "skipping network security group reconcile in custom vnet mode",
sgName: "my-sg",
isControlPlane: false,
vnetSpec: &infrav1.VnetSpec{ResourceGroup: "custom-vnet-rg", Name: "custom-vnet", ID: "id1"},
expect: func(m *mock_securitygroups.MockClientMockRecorder) {
expect: func(m *mock_securitygroups.MockClientMockRecorder, m1 *mock_securitygroups.MockClientMockRecorder) {

},
},
Expand All @@ -91,7 +93,7 @@ func TestReconcileSecurityGroups(t *testing.T) {

client := fake.NewFakeClient(cluster)

tc.expect(sgMock.EXPECT())
tc.expect(sgMock.EXPECT(), sgMock.EXPECT())

clusterScope, err := scope.NewClusterScope(scope.ClusterScopeParams{
AzureClients: scope.AzureClients{
Expand Down

0 comments on commit a2760be

Please sign in to comment.